Skip to content

Commit 927a5e7

Browse files
fix: check_Y problem of indexes
1 parent 4ca1807 commit 927a5e7

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,9 @@ def _check_categorical_variables(self, X: np.ndarray) -> None:
410410
f"Columns {1} to {X.shape[1] - 1} of X_train must be castable in integer format."
411411
)
412412

413-
for j in range(1, X.shape[1]):
414-
max_cat_value = categorical_variables.max()
415-
if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j - 1]:
413+
for j in range(X.shape[1] - 1):
414+
max_cat_value = categorical_variables[:, j].max()
415+
if max_cat_value >= self.categorical_var_net.categorical_vocabulary_sizes[j]:
416416
raise ValueError(
417417
f"Categorical variable at index {j} has value {max_cat_value} which exceeds the vocabulary size of {self.categorical_var_net.categorical_vocabulary_sizes[j]}."
418418
)

0 commit comments

Comments
 (0)