Skip to content

Commit b923c8d

Browse files
chore: X_val, y_val are optional
1 parent c9c5790 commit b923c8d

1 file changed

Lines changed: 33 additions & 23 deletions

File tree

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ def train(
184184
self,
185185
X_train: np.ndarray,
186186
y_train: np.ndarray,
187-
X_val: np.ndarray,
188-
y_val: np.ndarray,
189187
training_config: TrainingConfig,
188+
X_val: Optional[np.ndarray] = None,
189+
y_val: Optional[np.ndarray] = None,
190190
verbose: bool = False,
191191
) -> None:
192192
"""Train the classifier using PyTorch Lightning.
@@ -224,7 +224,14 @@ def train(
224224
"""
225225
# Input validation
226226
X_train, y_train = self._check_XY(X_train, y_train)
227-
X_val, y_val = self._check_XY(X_val, y_val)
227+
228+
if X_val is not None:
229+
assert y_val is not None, "y_val must be provided if X_val is provided."
230+
if y_val is not None:
231+
assert X_val is not None, "X_val must be provided if y_val is provided."
232+
233+
if X_val is not None and y_val is not None:
234+
X_val, y_val = self._check_XY(X_val, y_val)
228235

229236
if (
230237
X_train["categorical_variables"] is not None
@@ -277,40 +284,43 @@ def train(
277284
texts=X_train["text"],
278285
categorical_variables=X_train["categorical_variables"], # None if no cat vars
279286
tokenizer=self.tokenizer,
280-
labels=y_train,
281-
ragged_multilabel=self.ragged_multilabel,
282-
)
283-
val_dataset = TextClassificationDataset(
284-
texts=X_val["text"],
285-
categorical_variables=X_val["categorical_variables"], # None if no cat vars
286-
tokenizer=self.tokenizer,
287-
labels=y_val,
287+
labels=y_train.tolist(),
288288
ragged_multilabel=self.ragged_multilabel,
289289
)
290-
291290
train_dataloader = train_dataset.create_dataloader(
292291
batch_size=training_config.batch_size,
293292
num_workers=training_config.num_workers,
294293
shuffle=True,
295294
**training_config.dataloader_params if training_config.dataloader_params else {},
296295
)
297-
val_dataloader = val_dataset.create_dataloader(
298-
batch_size=training_config.batch_size,
299-
num_workers=training_config.num_workers,
300-
shuffle=False,
301-
**training_config.dataloader_params if training_config.dataloader_params else {},
302-
)
296+
297+
if X_val is not None and y_val is not None:
298+
val_dataset = TextClassificationDataset(
299+
texts=X_val["text"],
300+
categorical_variables=X_val["categorical_variables"], # None if no cat vars
301+
tokenizer=self.tokenizer,
302+
labels=y_val,
303+
ragged_multilabel=self.ragged_multilabel,
304+
)
305+
val_dataloader = val_dataset.create_dataloader(
306+
batch_size=training_config.batch_size,
307+
num_workers=training_config.num_workers,
308+
shuffle=False,
309+
**training_config.dataloader_params if training_config.dataloader_params else {},
310+
)
311+
else:
312+
val_dataloader = None
303313

304314
# Setup trainer
305315
callbacks = [
306316
ModelCheckpoint(
307-
monitor="val_loss",
317+
monitor="val_loss" if val_dataloader is not None else "train_loss",
308318
save_top_k=1,
309319
save_last=False,
310320
mode="min",
311321
),
312322
EarlyStopping(
313-
monitor="val_loss",
323+
monitor="val_loss" if val_dataloader is not None else "train_loss",
314324
patience=training_config.patience_early_stopping,
315325
mode="min",
316326
),
@@ -442,9 +452,9 @@ def _check_Y(self, Y):
442452

443453
else:
444454
assert isinstance(Y, np.ndarray), "Y must be a numpy array of shape (N,) or (N,1)."
445-
assert len(Y.shape) == 1 or (
446-
len(Y.shape) == 2 and Y.shape[1] == 1
447-
), "Y must be a numpy array of shape (N,) or (N,1)."
455+
assert (
456+
len(Y.shape) == 1 or len(Y.shape) == 2
457+
), "Y must be a numpy array of shape (N,) or (N, num_labels)."
448458

449459
try:
450460
Y = Y.astype(int)

0 commit comments

Comments (0)