A PyTorch implementation of the β-Variational Autoencoder (β-VAE) for learning interpretable and disentangled latent representations. This implementation allows you to train models that separate underlying factors of variation in your data.
- Webpage: β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework
- Dataset: CelebA
β-VAE extends the standard VAE by introducing a hyperparameter β that controls the trade-off between reconstruction quality and disentanglement in the latent space. Higher β values encourage more disentangled representations, where individual latent dimensions correspond to independent factors of variation.
- Complete β-VAE Implementation: Encoder, decoder, and reparameterization trick
- Flexible Architecture: Customizable hidden dimensions and latent space size
- Training Pipeline: Full training loop with validation and checkpointing
- Visualization Tools:
- Reconstruction comparison
- Random sampling from latent space
- Latent space traversal (manipulate individual dimensions)
- Latent interpolation between images
- TensorBoard Integration: Real-time training monitoring
- Custom Dataset Support: Easy integration with your own image datasets
```bash
pip install torch torchvision matplotlib numpy tqdm pillow tensorboard
```
- Prepare Your Data: Place your images in a folder structure:
```
CelebA/
└── img_align_celeba/
    ├── image1.jpg
    ├── image2.jpg
    └── ...
```
- Configure Training: Modify the configuration dictionary:
```python
config = {
    'data_path': './CelebA/img_align_celeba',
    'batch_size': 32,
    'img_size': 64,
    'latent_dim': 128,
    'hidden_dims': [32, 64, 128, 256],
    'beta': 4.0,  # Disentanglement strength
    'lr': 1e-4,
    'epochs': 50,
}
```
- Run Training: Execute the notebook cells sequentially or convert to a Python script.
- Monitor Progress: View training metrics with TensorBoard:
```bash
tensorboard --logdir=./runs
```
- Convolutional layers with BatchNorm and LeakyReLU
- Outputs mean (μ) and log-variance (log σ²) for the latent distribution
- Default: 4 conv layers → 128-D latent space
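The reparameterization trick mentioned above turns sampling from N(μ, σ²) into a differentiable operation; a minimal sketch (the standalone `reparameterize` function here is illustrative, not necessarily this codebase's exact method):

```python
import torch

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I),
    # so gradients flow through mu and logvar
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```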
- Transposed convolutional layers
- Reconstructs images from latent codes
- Sigmoid activation for output normalization
L = Reconstruction Loss + β × KL Divergence
- Reconstruction Loss: MSE between input and output
- KL Divergence: Regularization term encouraging Gaussian latent distribution
- β: Controls disentanglement (typical range: 1–10)
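The loss above can be written out directly; a minimal sketch with summed MSE and the closed-form Gaussian KL term (`beta_vae_loss` is a hypothetical helper name):

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(recon, x, mu, logvar, beta=4.0):
    # Reconstruction term: MSE between input and reconstruction
    recon_loss = F.mse_loss(recon, x, reduction='sum')
    # KL divergence of N(mu, sigma^2) from the standard normal prior
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl
```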
| Parameter | Description | Default | Tuning Tips |
|---|---|---|---|
| `beta` | Disentanglement strength | 4.0 | Higher → more disentangled but worse reconstruction |
| `latent_dim` | Latent space dimensions | 128 | More dims → more capacity but harder to interpret |
| `hidden_dims` | Encoder/decoder layer sizes | [32, 64, 128, 256] | Adjust based on image complexity |
| `learning_rate` | Optimizer learning rate | 1e-4 | Reduce if training is unstable |
Compare original images with their reconstructions to evaluate model performance.
Manipulate individual latent dimensions to discover learned features:
- Dimension 5 might control lighting
- Dimension 10 might control rotation
- Dimension 15 might control expression
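A traversal can be sketched as follows (the `model.decode` call and the helper name `traverse_latent` are assumptions about this codebase):

```python
import torch

def traverse_latent(model, z, dim, values):
    # Vary one latent dimension while holding the rest fixed
    frames = []
    for v in values:
        z_mod = z.clone()
        z_mod[:, dim] = v
        frames.append(model.decode(z_mod))
    return torch.stack(frames)
```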
Smoothly transition between two images by interpolating in latent space.
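Linear interpolation between two latent codes might look like this (a hypothetical helper; `model.decode` is assumed):

```python
import torch

def interpolate(model, z1, z2, steps=8):
    # Decode evenly spaced points on the line from z1 to z2
    alphas = torch.linspace(0, 1, steps)
    zs = torch.stack([(1 - a) * z1 + a * z2 for a in alphas])
    return model.decode(zs)
```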
```
├── Disentanglement_Bvae.ipynb   # Main implementation notebook
├── CelebA/                      # Dataset directory
├── checkpoints/                 # Saved models
│   ├── best_model.pt
│   └── checkpoint_epoch_*.pt
├── runs/                        # TensorBoard logs
└── README.md                    # This file
```
```python
import os

from PIL import Image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Load your image paths here
        self.image_paths = sorted(
            os.path.join(root_dir, f) for f in os.listdir(root_dir)
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Return the transformed image
        image = Image.open(self.image_paths[idx]).convert('RGB')
        return self.transform(image) if self.transform else image
```
```python
checkpoint = torch.load('checkpoints/best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```
```python
# Sample from prior distribution
num_samples = 16
samples = model.sample(num_samples, device)
```
- Start with β=1: Train a standard VAE first, then gradually increase β
- Monitor KL Divergence: Should stabilize after initial epochs
- Adjust Learning Rate: Use ReduceLROnPlateau scheduler for adaptive learning
- Checkpoint Regularly: Save models every 10 epochs
- Visualize Early: Check reconstructions after 5-10 epochs
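The scheduler and gradient-clipping tips can be combined in a training step; a sketch with a stand-in model and loss (the real β-VAE model and loss would take their place):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the VAE
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()  # stand-in for the beta-VAE loss

optimizer.zero_grad()
loss.backward()
# Clip gradients to stabilize training
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Once per epoch, step the scheduler on the validation loss
scheduler.step(loss.item())
```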
Poor Reconstructions:
- Decrease Ξ² value
- Increase latent dimensions
- Train for more epochs
- Check learning rate
Not Disentangled:
- Increase β gradually (4 → 6 → 8)
- Ensure diverse training data
- Increase model capacity
- Train longer
Training Instability:
- Reduce learning rate
- Add gradient clipping
- Check data normalization
- Reduce batch size
- β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework (Higgins et al., 2017)
- Understanding disentangling in β-VAE (Burgess et al., 2018)
- Variational Autoencoders | Generative AI Animated (Deepia, 2024)
If you use this implementation in your research, please cite:
```bibtex
@misc{bvae-implementation,
  author    = {Vishva MV},
  title     = {β-VAE Implementation for Disentangled Representations},
  year      = {2025},
  publisher = {GitHub},
  url       = {https://github.com/Vishva2003/beta-vae}
}
```
Contributions are welcome! Please feel free to submit a Pull Request. For major changes:
- Fork the repository
- Create your feature branch (`git checkout -b feature/AmazingFeature`)
- Commit your changes (`git commit -m 'Add some AmazingFeature'`)
- Push to the branch (`git push origin feature/AmazingFeature`)
- Open a Pull Request
This project is licensed under the MIT License - see the LICENSE file for details.
- Original β-VAE authors for the groundbreaking research
- Community contributors and researchers in disentangled representation learning
For questions or feedback, please open an issue on GitHub or contact dev.vishvamv@mail.com
Happy Learning! If you find this useful, please consider starring the repository!