@@ -56,320 +56,4 @@ def compute_loss(self, batch_data, criterion):
5656 output , hidden , cell = self .forward (ipt , hidden , cell )
5757 output = output .view (output .size (0 ) * output .size (1 ), - 1 )
5858 loss = criterion (output , tgt .view (- 1 ))
59- return loss
60-
61- ## Define SmilesRnnSampler
62- # class SMILESSampler:
63- # """
64- # Samples molecules from an RNN smiles language model
65- # """
66- # def __init__(self, device: str, batch_size=64) -> None:
67- # """
68- # Args:
69- # device: cpu | cuda
70- # batch_size: number of concurrent samples to generate
71- # """
72- # self.device = device
73- # self.batch_size = batch_size
74- # self.sd = SmilesCharDictionary()
75-
76- # def sample(self, model: LSTM, num_to_sample: int, max_seq_len=100):
77- # """
78-
79- # Args:
80- # model: RNN to sample from
81- # num_to_sample: number of samples to produce
82- # max_seq_len: maximum length of the samples
83- # batch_size: number of concurrent samples to generate
84-
85- # Returns: a list of SMILES string, with no beginning nor end symbols
86-
87- # """
88- # sampler = ActionSampler(max_batch_size=self.batch_size, max_seq_length=max_seq_len, device=self.device)
89-
90- # model.eval()
91- # with torch.no_grad():
92- # indices = sampler.sample(model, num_samples=num_to_sample)
93- # return self.sd.matrix_to_smiles(indices)
94-
95- # define SmilesRnnTrainer
96-
97- # class SmilesRnnTrainer:
98- # def __init__(self, model, criteria, optimizer, device, log_dir=None, clip_gradients=True) -> None:
99- # self.model = model.to(device)
100- # self.criteria = [c.to(device) for c in criteria]
101- # self.optimizer = optimizer
102- # self.device = device
103- # self.log_dir = log_dir
104- # self.clip_gradients = clip_gradients
105-
106- # def process_batch(self, batch):
107-
108- # # ship data to device
109- # inp, tgt = batch
110- # inp = inp.to(self.device)
111- # tgt = tgt.to(self.device)
112-
113- # # process data
114- # batch_size = inp.size(0)
115- # hidden = self.model.init_hidden(inp.size(0), self.device)
116- # output, hidden = self.model(inp, hidden)
117- # output = output.view(output.size(0) * output.size(1), -1)
118- # loss = self.criteria[0](output, tgt.view(-1))
119- # return loss, batch_size
120-
121- # def train_on_batch(self, batch):
122-
123- # # setup model for training
124- # self.model.train()
125- # self.model.zero_grad()
126-
127- # # forward / backward
128- # loss, size = self.process_batch(batch)
129- # loss.backward()
130-
131- # # optimize
132- # if self.clip_gradients:
133- # nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
134- # self.optimizer.step()
135-
136- # return loss.item(), size
137-
138- # def test_on_batch(self, batch):
139-
140- # # setup model for evaluation
141- # self.model.eval()
142-
143- # # forward
144- # loss, size = self.process_batch(batch)
145-
146- # return loss.item(), size
147-
148- # def validate(self, data_loader, n_molecule):
149- # """Runs validation and reports the average loss"""
150- # valid_losses = []
151- # with torch.no_grad():
152- # for batch in data_loader:
153- # loss, size = self.test_on_batch(batch)
154- # valid_losses += [loss]
155- # return np.array(valid_losses).mean()
156-
157- # def train_extra_log(self, n_molecules):
158- # pass
159-
160- # def valid_extra_log(self, n_molecules):
161- # pass
162-
163- # def fit(self, training_data, test_data, n_epochs, batch_size, print_every,
164- # valid_every, num_workers=0):
165- # training_round = _ModelTrainingRound(self, training_data, test_data, n_epochs, batch_size, print_every,
166- # valid_every, num_workers)
167- # return training_round.run()
168-
169-
170- # class _ModelTrainingRound:
171- # """
172- # Performs one round of model training.
173-
174- # Is a separate class from ModelTrainer to allow for more modular functions without too many parameters.
175- # This class is not to be used outside of ModelTrainer.
176- # """
177- # class EarlyStopNecessary(Exception):
178- # pass
179-
180- # def __init__(self, model_trainer: SmilesRnnTrainer, training_data, test_data, n_epochs, batch_size, print_every,
181- # valid_every, num_workers=0) -> None:
182- # self.model_trainer = model_trainer
183- # self.training_data = training_data
184- # self.test_data = test_data
185- # self.n_epochs = n_epochs
186- # self.batch_size = batch_size
187- # self.print_every = print_every
188- # self.valid_every = valid_every
189- # self.num_workers = num_workers
190-
191- # self.start_time = time.time()
192- # self.unprocessed_train_losses: List[float] = []
193- # self.all_train_losses: List[float] = []
194- # self.all_valid_losses: List[float] = []
195- # self.n_molecules_so_far = 0
196- # self.has_run = False
197- # self.min_valid_loss = np.inf
198- # self.min_avg_train_loss = np.inf
199-
200- # def run(self):
201- # if self.has_run:
202- # raise Exception('_ModelTrainingRound.train() can be called only once.')
203-
204- # try:
205- # for epoch_index in range(1, self.n_epochs + 1):
206- # self._train_one_epoch(epoch_index)
207-
208- # self._validation_on_final_model()
209- # except _ModelTrainingRound.EarlyStopNecessary:
210- # logger.error('Probable explosion during training. Stopping now.')
211-
212- # self.has_run = True
213- # return self.all_train_losses, self.all_valid_losses
214-
215- # def _train_one_epoch(self, epoch_index: int):
216- # logger.info(f'EPOCH {epoch_index}')
217-
218- # # shuffle at every epoch
219- # data_loader = DataLoader(self.training_data,
220- # batch_size=self.batch_size,
221- # shuffle=True,
222- # num_workers=self.num_workers,
223- # pin_memory=True)
224-
225- # epoch_t0 = time.time()
226- # self.unprocessed_train_losses.clear()
227-
228- # for batch_index, batch in enumerate(data_loader):
229- # self._train_one_batch(batch_index, batch, epoch_index, epoch_t0)
230-
231- # def _train_one_batch(self, batch_index, batch, epoch_index, train_t0):
232- # loss, size = self.model_trainer.train_on_batch(batch)
233-
234- # self.unprocessed_train_losses += [loss]
235- # self.n_molecules_so_far += size
236-
237- # # report training progress?
238- # if batch_index > 0 and batch_index % self.print_every == 0:
239- # self._report_training_progress(batch_index, epoch_index, epoch_start=train_t0)
240-
241- # # report validation progress?
242- # if batch_index >= 0 and batch_index % self.valid_every == 0:
243- # self._report_validation_progress(epoch_index)
244-
245- # def _report_training_progress(self, batch_index, epoch_index, epoch_start):
246- # mols_sec = self._calculate_mols_per_second(batch_index, epoch_start)
247-
248- # # Update train losses by processing all losses since last time this function was executed
249- # avg_train_loss = np.array(self.unprocessed_train_losses).mean()
250- # self.all_train_losses += avg_train_loss
251- # self.unprocessed_train_losses.clear()
252-
253- # logger.info(
254- # 'TRAIN | '
255- # f'elapsed: {time_since(self.start_time)} | '
256- # f'epoch|batch : {epoch_index}|{batch_index} ({self._get_overall_progress():.1f}%) | '
257- # f'molecules: {self.n_molecules_so_far} | '
258- # f'mols/sec: {mols_sec:.2f} | '
259- # f'train_loss: {avg_train_loss:.4f}')
260- # self.model_trainer.train_extra_log(self.n_molecules_so_far)
261-
262- # self._check_early_stopping_train_loss(avg_train_loss)
263-
264- # def _calculate_mols_per_second(self, batch_index, epoch_start):
265- # """
266- # Calculates the speed so far in the current epoch.
267- # """
268- # train_time_in_current_epoch = time.time() - epoch_start
269- # processed_batches = batch_index + 1
270- # molecules_in_current_epoch = self.batch_size * processed_batches
271- # return molecules_in_current_epoch / train_time_in_current_epoch
272-
273- # def _report_validation_progress(self, epoch_index):
274- # avg_valid_loss = self._validate_current_model()
275-
276- # self._log_validation_step(epoch_index, avg_valid_loss)
277- # self._check_early_stopping_validation(avg_valid_loss)
278-
279- # # save model?
280- # if self.model_trainer.log_dir:
281- # if avg_valid_loss <= min(self.all_valid_losses):
282- # self._save_current_model(self.model_trainer.log_dir, epoch_index, avg_valid_loss)
283-
284- # def _validate_current_model(self):
285- # """
286- # Validate the current model.
287-
288- # Returns: Validation loss.
289- # """
290- # test_loader = DataLoader(self.test_data,
291- # batch_size=self.batch_size,
292- # shuffle=False,
293- # num_workers=self.num_workers,
294- # pin_memory=True)
295- # avg_valid_loss = self.model_trainer.validate(test_loader, self.n_molecules_so_far)
296- # self.all_valid_losses += [avg_valid_loss]
297- # return avg_valid_loss
298-
299- # def _log_validation_step(self, epoch_index, avg_valid_loss):
300- # """
301- # Log the information about the validation step.
302- # """
303- # logger.info(
304- # 'VALID | '
305- # f'elapsed: {time_since(self.start_time)} | '
306- # f'epoch: {epoch_index}/{self.n_epochs} ({self._get_overall_progress():.1f}%) | '
307- # f'molecules: {self.n_molecules_so_far} | '
308- # f'valid_loss: {avg_valid_loss:.4f}')
309- # self.model_trainer.valid_extra_log(self.n_molecules_so_far)
310- # logger.info('')
311-
312- # def _get_overall_progress(self):
313- # total_mols = self.n_epochs * len(self.training_data)
314- # return 100. * self.n_molecules_so_far / total_mols
315-
316- # def _validation_on_final_model(self):
317- # """
318- # Run validation for the final model and save it.
319- # """
320- # valid_loss = self._validate_current_model()
321- # logger.info(
322- # 'VALID | FINAL_MODEL | '
323- # f'elapsed: {time_since(self.start_time)} | '
324- # f'molecules: {self.n_molecules_so_far} | '
325- # f'valid_loss: {valid_loss:.4f}')
326-
327- # if self.model_trainer.log_dir:
328- # self._save_model(self.model_trainer.log_dir, 'final', valid_loss)
329-
330- # def _save_current_model(self, base_dir, epoch, valid_loss):
331- # """
332- # Delete previous versions of the model and save the current one.
333- # """
334- # for f in glob(os.path.join(base_dir, 'model_*')):
335- # os.remove(f)
336-
337- # self._save_model(base_dir, epoch, valid_loss)
338-
339- # def _save_model(self, base_dir, info, valid_loss):
340- # """
341- # Save a copy of the model with format:
342- # model_{info}_{valid_loss}
343- # """
344- # base_name = f'model_{info}_{valid_loss:.3f}'
345- # logger.info(base_name)
346- # save_model(self.model_trainer.model, base_dir, base_name)
347-
348- # def _check_early_stopping_train_loss(self, avg_train_loss):
349- # """
350- # This function checks whether the training has exploded by verifying if the avg training loss
351- # is more than 10 times the minimal loss so far.
352-
353- # If this is the case, a EarlyStopNecessary exception is raised.
354- # """
355- # threshold = 10 * self.min_avg_train_loss
356- # if avg_train_loss > threshold:
357- # raise _ModelTrainingRound.EarlyStopNecessary()
358-
359- # # update the min train loss if necessary
360- # if avg_train_loss < self.min_avg_train_loss:
361- # self.min_avg_train_loss = avg_train_loss
362-
363- # def _check_early_stopping_validation(self, avg_valid_loss):
364- # """
365- # This function checks whether the training has exploded by verifying if the validation loss
366- # has more than doubled compared to the minimum validation loss so far.
367-
368- # If this is the case, a EarlyStopNecessary exception is raised.
369- # """
370- # threshold = 2 * self.min_valid_loss
371- # if avg_valid_loss > threshold:
372- # raise _ModelTrainingRound.EarlyStopNecessary()
373-
374- # if avg_valid_loss < self.min_valid_loss:
375- # self.min_valid_loss = avg_valid_loss
59+ return loss
0 commit comments