@@ -130,8 +130,9 @@ def main():
130130 )
131131
132132 classifier .train (
133- X_train , y_train , X_val , y_val ,
133+ X_train , y_train ,
134134 training_config = training_config ,
135+ X_val = X_val , y_val = y_val ,
135136 verbose = True
136137 )
137138
@@ -172,13 +173,13 @@ def main():
172173 lr = 1e-3 ,
173174 patience_early_stopping = 7 ,
174175 num_workers = 0 ,
175- cpu_run = False , # Don't override accelerator from trainer_params
176176 trainer_params = advanced_trainer_params
177177 )
178178
179179 advanced_classifier .train (
180- X_train , y_train , X_val , y_val ,
180+ X_train , y_train ,
181181 training_config = advanced_training_config ,
182+ X_val = X_val , y_val = y_val ,
182183 verbose = True
183184 )
184185
@@ -206,14 +207,14 @@ def main():
206207 batch_size = 16 , # Larger batch size for CPU
207208 lr = 1e-3 ,
208209 patience_early_stopping = 3 ,
209- cpu_run = False , # Don't override accelerator from trainer_params
210210 num_workers = 0 , # No multiprocessing for CPU
211211 trainer_params = {'deterministic' : True , 'accelerator' : 'cpu' }
212212 )
213213
214214 cpu_classifier .train (
215- X_train , y_train , X_val , y_val ,
215+ X_train , y_train ,
216216 training_config = cpu_training_config ,
217+ X_val = X_val , y_val = y_val ,
217218 verbose = True
218219 )
219220
@@ -257,8 +258,9 @@ def main():
257258 )
258259
259260 custom_classifier .train (
260- X_train , y_train , X_val , y_val ,
261+ X_train , y_train ,
261262 training_config = custom_training_config ,
263+ X_val = X_val , y_val = y_val ,
262264 verbose = True
263265 )
264266
0 commit comments