@@ -184,9 +184,9 @@ def train(
184184 self ,
185185 X_train : np .ndarray ,
186186 y_train : np .ndarray ,
187- X_val : np .ndarray ,
188- y_val : np .ndarray ,
189187 training_config : TrainingConfig ,
188+ X_val : Optional [np .ndarray ] = None ,
189+ y_val : Optional [np .ndarray ] = None ,
190190 verbose : bool = False ,
191191 ) -> None :
192192 """Train the classifier using PyTorch Lightning.
@@ -224,7 +224,14 @@ def train(
224224 """
225225 # Input validation
226226 X_train , y_train = self ._check_XY (X_train , y_train )
227- X_val , y_val = self ._check_XY (X_val , y_val )
227+
228+ if X_val is not None :
229+ assert y_val is not None , "y_val must be provided if X_val is provided."
230+ if y_val is not None :
231+ assert X_val is not None , "X_val must be provided if y_val is provided."
232+
233+ if X_val is not None and y_val is not None :
234+ X_val , y_val = self ._check_XY (X_val , y_val )
228235
229236 if (
230237 X_train ["categorical_variables" ] is not None
@@ -277,40 +284,43 @@ def train(
277284 texts = X_train ["text" ],
278285 categorical_variables = X_train ["categorical_variables" ], # None if no cat vars
279286 tokenizer = self .tokenizer ,
280- labels = y_train ,
281- ragged_multilabel = self .ragged_multilabel ,
282- )
283- val_dataset = TextClassificationDataset (
284- texts = X_val ["text" ],
285- categorical_variables = X_val ["categorical_variables" ], # None if no cat vars
286- tokenizer = self .tokenizer ,
287- labels = y_val ,
287+ labels = y_train .tolist (),
288288 ragged_multilabel = self .ragged_multilabel ,
289289 )
290-
291290 train_dataloader = train_dataset .create_dataloader (
292291 batch_size = training_config .batch_size ,
293292 num_workers = training_config .num_workers ,
294293 shuffle = True ,
295294 ** training_config .dataloader_params if training_config .dataloader_params else {},
296295 )
297- val_dataloader = val_dataset .create_dataloader (
298- batch_size = training_config .batch_size ,
299- num_workers = training_config .num_workers ,
300- shuffle = False ,
301- ** training_config .dataloader_params if training_config .dataloader_params else {},
302- )
296+
297+ if X_val is not None and y_val is not None :
298+ val_dataset = TextClassificationDataset (
299+ texts = X_val ["text" ],
300+ categorical_variables = X_val ["categorical_variables" ], # None if no cat vars
301+ tokenizer = self .tokenizer ,
302+ labels = y_val ,
303+ ragged_multilabel = self .ragged_multilabel ,
304+ )
305+ val_dataloader = val_dataset .create_dataloader (
306+ batch_size = training_config .batch_size ,
307+ num_workers = training_config .num_workers ,
308+ shuffle = False ,
309+ ** training_config .dataloader_params if training_config .dataloader_params else {},
310+ )
311+ else :
312+ val_dataloader = None
303313
304314 # Setup trainer
305315 callbacks = [
306316 ModelCheckpoint (
307- monitor = "val_loss" ,
317+ monitor = "val_loss" if val_dataloader is not None else "train_loss" ,
308318 save_top_k = 1 ,
309319 save_last = False ,
310320 mode = "min" ,
311321 ),
312322 EarlyStopping (
313- monitor = "val_loss" ,
323+ monitor = "val_loss" if val_dataloader is not None else "train_loss" ,
314324 patience = training_config .patience_early_stopping ,
315325 mode = "min" ,
316326 ),
@@ -442,9 +452,9 @@ def _check_Y(self, Y):
442452
443453 else :
444454 assert isinstance (Y , np .ndarray ), "Y must be a numpy array of shape (N,) or (N,1)."
445- assert len ( Y . shape ) == 1 or (
446- len (Y .shape ) == 2 and Y .shape [ 1 ] == 1
447- ), "Y must be a numpy array of shape (N,) or (N,1 )."
455+ assert (
456+ len (Y .shape ) == 1 or len ( Y .shape ) == 2
457+ ), "Y must be a numpy array of shape (N,) or (N, num_labels )."
448458
449459 try :
450460 Y = Y .astype (int )
0 commit comments