
Multiclass Classification Tutorial

Learn how to build a multiclass sentiment classifier with 3 classes: negative, neutral, and positive.

Learning Objectives

By the end of this tutorial, you will be able to:

  • Handle multiclass classification problems (3+ classes)
  • Configure models for multiple output classes
  • Ensure reproducible results with proper seeding
  • Evaluate multiclass performance
  • Understand class distribution and balance

Prerequisites

  • Completion of the {doc}`basic_classification` tutorial (recommended)
  • Basic understanding of classification
  • torchTextClassifiers installed

Overview

In this tutorial, we'll build a 3-class sentiment classifier that categorizes product reviews as:

  • Negative (class 0): Bad reviews
  • Neutral (class 1): Mixed or moderate reviews
  • Positive (class 2): Good reviews

Key Difference from Binary Classification:

  • Binary: 2 classes (positive/negative)
  • Multiclass: 3+ classes (negative/neutral/positive)

Complete Code

import os
import numpy as np
import torch
from pytorch_lightning import seed_everything
from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
from torchTextClassifiers.tokenizers import WordPieceTokenizer

# Step 1: Set Seeds for Reproducibility
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
seed_everything(SEED, workers=True)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=True)

# Step 2: Prepare Multi-class Data
X_train = np.array([
    # Negative (class 0)
    "This product is terrible and I hate it completely.",
    "Worst purchase ever. Total waste of money.",
    "Absolutely awful quality. Very disappointed.",
    "Poor service and terrible product quality.",
    "I regret buying this. Complete failure.",

    # Neutral (class 1)
    "The product is okay, nothing special though.",
    "It works but could be better designed.",
    "Average quality for the price point.",
    "Not bad but not great either.",
    "It's fine, meets basic expectations.",

    # Positive (class 2)
    "Excellent product! Highly recommended!",
    "Amazing quality and great customer service.",
    "Perfect! Exactly what I was looking for.",
    "Outstanding value and excellent performance.",
    "Love it! Will definitely buy again."
])

y_train = np.array([0, 0, 0, 0, 0,  # negative
                    1, 1, 1, 1, 1,  # neutral
                    2, 2, 2, 2, 2]) # positive

# Validation data
X_val = np.array([
    "Bad quality, not recommended.",
    "It's okay, does the job.",
    "Great product, very satisfied!"
])
y_val = np.array([0, 1, 2])

# Test data
X_test = np.array([
    "This is absolutely horrible!",
    "It's an average product, nothing more.",
    "Fantastic! Love every aspect of it!",
    "Really poor design and quality.",
    "Works well, good value for money.",
    "Outstanding product with amazing features!"
])
y_test = np.array([0, 1, 2, 0, 1, 2])

# Step 3: Create and Train Tokenizer
tokenizer = WordPieceTokenizer(vocab_size=5000, output_dim=128)
tokenizer.train(X_train.tolist())

# Step 4: Configure Model for 3 Classes
model_config = ModelConfig(
    embedding_dim=64,
    num_classes=3  # KEY: 3 classes for multiclass
)

# Step 5: Create Classifier
classifier = torchTextClassifiers(
    tokenizer=tokenizer,
    model_config=model_config
)

# Step 6: Train Model
training_config = TrainingConfig(
    num_epochs=30,
    batch_size=8,
    lr=1e-3,
    patience_early_stopping=7,
    num_workers=0,
    trainer_params={'deterministic': True}
)

classifier.train(
    X_train, y_train,
    training_config,
    X_val, y_val,
    verbose=True
)

# Step 7: Make Predictions
result = classifier.predict(X_test)
predictions = result["prediction"].squeeze().numpy()

# Step 8: Evaluate
accuracy = (predictions == y_test).mean()
print(f"Test accuracy: {accuracy:.3f}")

# Show results with class names
class_names = ["Negative", "Neutral", "Positive"]
for text, pred, true in zip(X_test, predictions, y_test):
    predicted = class_names[pred]
    actual = class_names[true]
    status = "✅" if pred == true else "❌"
    print(f"{status} Predicted: {predicted}, True: {actual}")
    print(f"   Text: {text}")

Step-by-Step Walkthrough

Step 1: Ensuring Reproducibility

For consistent results across runs, set seeds properly:

SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
seed_everything(SEED, workers=True)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True, warn_only=True)

Why this matters:

  • Makes experiments reproducible
  • Enables fair comparison of hyperparameters
  • Helps debug model behavior

:::{tip}
Always set seeds when reporting results or comparing models!
:::

Step 2: Preparing Multiclass Data

Unlike binary classification, you now have 3 classes:

y_train = np.array([0, 0, 0, 0, 0,  # negative (class 0)
                    1, 1, 1, 1, 1,  # neutral (class 1)
                    2, 2, 2, 2, 2]) # positive (class 2)

Important: Class labels should be:

  • Integers: 0, 1, 2, ... (not strings)
  • Contiguous: Start from 0, no gaps (0, 1, 2 not 0, 2, 5)
  • Balanced: Ideally equal samples per class
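The first two requirements are easy to verify programmatically. Here is a minimal sanity-check sketch using NumPy (the `check_labels` helper is illustrative, not part of the library):

```python
import numpy as np

def check_labels(y):
    """Verify labels are integers 0..K-1 with no gaps; return K."""
    y = np.asarray(y)
    classes = np.unique(y)  # sorted unique labels
    assert np.issubdtype(y.dtype, np.integer), "labels must be integers"
    assert classes[0] == 0, "labels must start at 0"
    assert np.array_equal(classes, np.arange(len(classes))), "labels must have no gaps"
    return len(classes)

num_classes = check_labels(np.array([0, 0, 1, 1, 2, 2]))  # → 3
```

Running this on labels like `[0, 2, 5]` fails fast, which is much easier to debug than a shape error deep inside training.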

Check class distribution:

print(f"Negative: {sum(y_train==0)}")
print(f"Neutral: {sum(y_train==1)}")
print(f"Positive: {sum(y_train==2)}")

Output:

Negative: 5
Neutral: 5
Positive: 5

:::{note}
This example has perfectly balanced classes (5 samples each). Real datasets are often imbalanced.
:::

Step 3-4: Model Configuration

The only difference from binary classification:

model_config = ModelConfig(
    embedding_dim=64,
    num_classes=3  # Change from 2 to 3
)

Under the hood:

  • Binary: Uses 2 output neurons + CrossEntropyLoss
  • Multiclass: Uses 3 output neurons + CrossEntropyLoss
  • The loss function handles both automatically!
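To see why no code change is needed beyond `num_classes`, here is a standalone PyTorch illustration (independent of torchTextClassifiers): `CrossEntropyLoss` takes raw logits of shape `(batch, num_classes)` and integer targets, and works the same whether `num_classes` is 2, 3, or 50.

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()

# Raw logits for a batch of 3 samples and 3 classes (no softmax needed;
# CrossEntropyLoss applies log-softmax internally)
logits = torch.tensor([[ 2.0, 0.1, -1.0],
                       [ 0.2, 1.5,  0.3],
                       [-0.5, 0.0,  2.2]])
targets = torch.tensor([0, 1, 2])  # true class indices

loss = loss_fn(logits, targets)
print(loss.item())  # a positive scalar; lower means a better fit
```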

Step 5-6: Training

Training is identical to binary classification:

classifier.train(
    X_train, y_train,
    training_config,
    X_val, y_val,
)

Training process:

  1. Forward pass: Text → Embeddings → Logits (3 values)
  2. Loss calculation: Compare logits to true labels
  3. Backward pass: Compute gradients
  4. Update weights: Optimizer step
  5. Repeat for each batch/epoch
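The five steps above correspond to a standard PyTorch training loop. A minimal sketch with a toy linear model standing in for the text encoder (this is an illustration of the pattern, not the library's internals):

```python
import torch
from torch import nn

model = nn.Linear(8, 3)            # toy stand-in for embeddings → 3 logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

X = torch.randn(16, 8)             # fake batch of 16 feature vectors
y = torch.randint(0, 3, (16,))     # integer labels in {0, 1, 2}

for epoch in range(5):
    logits = model(X)              # 1. forward pass → shape (16, 3)
    loss = loss_fn(logits, y)      # 2. loss calculation
    optimizer.zero_grad()
    loss.backward()                # 3. backward pass: compute gradients
    optimizer.step()               # 4. update weights
                                   # 5. repeat for each batch/epoch
```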

Step 7: Making Predictions

Predictions now return values in {0, 1, 2}:

result = classifier.predict(X_test)
predictions = result["prediction"].squeeze().numpy()
# Example: [0, 1, 2, 0, 1, 2]

Probability interpretation:

You can also get probabilities for each class:

probabilities = result["confidence"].squeeze().numpy()
# Shape: (num_samples, 3)
# Each row sums to 1.0
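Under the hood, these per-class probabilities come from a softmax over the model's logits. A standalone illustration of that conversion:

```python
import torch

logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.2, 1.5,  0.3]])
probabilities = torch.softmax(logits, dim=1)

print(probabilities.sum(dim=1))          # each row sums to 1.0
predicted = probabilities.argmax(dim=1)  # class with the highest probability
print(predicted)                         # tensor([0, 1])
```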

Step 8: Evaluation

For multiclass, use class names for clarity:

class_names = ["Negative", "Neutral", "Positive"]

for text, pred, true in zip(X_test, predictions, y_test):
    status = "✅" if pred == true else "❌"
    print(f"{status} Predicted: {class_names[pred]}, True: {class_names[true]}")
    print(f"   Text: {text}")

Output (first three samples):

✅ Predicted: Negative, True: Negative
   Text: This is absolutely horrible!

✅ Predicted: Neutral, True: Neutral
   Text: It's an average product, nothing more.

✅ Predicted: Positive, True: Positive
   Text: Fantastic! Love every aspect of it!

Advanced: Class Imbalance

Real datasets often have unbalanced classes:

# Imbalanced example
y_train = [0]*100 + [1]*20 + [2]*10  # 100:20:10 ratio

Solutions:

1. Class Weights

Weight the loss function to penalize minority class errors more:

from torch import nn

# Calculate class weights
class_counts = np.bincount(y_train)
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum()  # Normalize

# Use weighted loss
training_config = TrainingConfig(
    ...
    loss=nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights))
)
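The weight computation itself runs standalone; for the 100:20:10 example above it assigns the largest weight to the rarest class (this shows the arithmetic only, not the `TrainingConfig` integration):

```python
import numpy as np

y_train = np.array([0]*100 + [1]*20 + [2]*10)

class_counts = np.bincount(y_train)            # [100, 20, 10]
class_weights = 1.0 / class_counts             # inverse frequency
class_weights = class_weights / class_weights.sum()  # normalize to sum to 1

print(class_counts)            # [100  20  10]
print(class_weights.round(4))  # rare classes get larger weights
```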

2. Oversampling/Undersampling

Balance the dataset before training:

from sklearn.utils import resample

# Oversample minority classes or undersample majority class
# (Use before creating the classifier)
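A minimal oversampling sketch with `sklearn.utils.resample` (toy data; in practice you would apply this to your real `X_train`/`y_train` before calling `classifier.train`):

```python
import numpy as np
from sklearn.utils import resample

# Imbalanced toy data: 6 samples of class 0, 2 of class 1
X = np.array(["a", "b", "c", "d", "e", "f", "g", "h"])
y = np.array([0, 0, 0, 0, 0, 0, 1, 1])

# Oversample the minority class up to the majority count (with replacement)
X_up, y_up = resample(X[y == 1], y[y == 1],
                      replace=True, n_samples=6, random_state=42)

X_balanced = np.concatenate([X[y == 0], X_up])
y_balanced = np.concatenate([y[y == 0], y_up])
print(np.bincount(y_balanced))  # [6 6]
```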

3. Data Augmentation

Generate synthetic samples for minority classes.

Evaluation Metrics

For multiclass problems, accuracy isn't enough. Use:

Confusion Matrix

from sklearn.metrics import confusion_matrix, classification_report

cm = confusion_matrix(y_test, predictions)
print("Confusion Matrix:")
print(cm)
#       Pred 0  Pred 1  Pred 2
# True 0 [[  2       0       0]
# True 1  [  0       2       0]
# True 2  [  0       0       2]]

Classification Report

report = classification_report(
    y_test, predictions,
    target_names=["Negative", "Neutral", "Positive"]
)
print(report)

Output:

              precision    recall  f1-score   support

    Negative       1.00      1.00      1.00         2
     Neutral       1.00      1.00      1.00         2
    Positive       1.00      1.00      1.00         2

    accuracy                           1.00         6
   macro avg       1.00      1.00      1.00         6
weighted avg       1.00      1.00      1.00         6

Metrics explained:

  • Precision: Of predicted class X, how many were correct?
  • Recall: Of true class X, how many did we find?
  • F1-score: Harmonic mean of precision and recall
  • Support: Number of samples in each class
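Each of these metrics can be computed directly from the confusion matrix. A small sketch for one class (rows are true classes, columns are predictions; the matrix values here are made up for illustration):

```python
import numpy as np

cm = np.array([[4, 1, 0],
               [1, 3, 1],
               [0, 1, 4]])  # rows = true class, columns = predicted class

k = 1  # the "Neutral" class
precision = cm[k, k] / cm[:, k].sum()  # correct class-k predictions / all predicted k
recall = cm[k, k] / cm[k, :].sum()     # correct class-k predictions / all true k
f1 = 2 * precision * recall / (precision + recall)
support = cm[k, :].sum()               # number of true class-k samples

print(precision, recall)  # 0.6 0.6
```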

Extending to More Classes

For 5 classes (e.g., star ratings 1-5):

# Data with 5 classes
y_train = np.array([0, 1, 2, 3, 4, ...])  # 0=1-star, 4=5-star

# Model configuration
model_config = ModelConfig(
    embedding_dim=64,
    num_classes=5  # Change to 5
)

The same code works for any number of classes!

Common Issues

Issue: Poor Performance on Middle Classes

Problem: Neutral class has low accuracy

Solution:

  1. Collect more neutral examples
  2. Make the distinction clearer in your data
  3. Consider if neutral is necessary (binary might be better)

Issue: Model Always Predicts One Class

Symptoms: All predictions are class 0 or class 2

Solutions:

  1. Check class balance - might be too imbalanced
  2. Verify labels are correct (0, 1, 2 not 1, 2, 3)
  3. Lower learning rate
  4. Train for more epochs

Issue: Overfitting

Symptoms: High training accuracy, low test accuracy

Solutions:

  1. Reduce embedding_dim
  2. Add more training data
  3. Use stronger early stopping (lower patience)

Next Steps

Now that you understand multiclass classification:

  1. Add categorical features: Combine text with metadata
  2. Try multilabel classification: Multiple labels per sample
  3. Use explainability: See which words matter for each class
  4. Explore advanced architectures: Add attention mechanisms

Complete Working Example

Find the full code in the repository.

Summary

In this tutorial, you learned:

  • ✅ How to set up multiclass classification (3+ classes)
  • ✅ How to configure num_classes correctly
  • ✅ How to ensure reproducible results with proper seeding
  • ✅ How to check and handle class distribution
  • ✅ How to evaluate multiclass models with confusion matrices
  • ✅ How to handle class imbalance

You're now ready to tackle real-world multiclass problems!