β-VAE: Learning Disentangled Representations

A PyTorch implementation of the β-Variational Autoencoder (β-VAE) for learning interpretable and disentangled latent representations. This implementation lets you train models that separate the underlying factors of variation in your data.

📋 Overview

β-VAE extends the standard VAE by introducing a hyperparameter β that controls the trade-off between reconstruction quality and disentanglement in the latent space. Higher β values encourage more disentangled representations, where individual latent dimensions correspond to independent factors of variation.

✨ Features

  • Complete β-VAE Implementation: Encoder, decoder, and reparameterization trick
  • Flexible Architecture: Customizable hidden dimensions and latent space size
  • Training Pipeline: Full training loop with validation and checkpointing
  • Visualization Tools:
    • Reconstruction comparison
    • Random sampling from latent space
    • Latent space traversal (manipulate individual dimensions)
    • Latent interpolation between images
  • TensorBoard Integration: Real-time training monitoring
  • Custom Dataset Support: Easy integration with your own image datasets

🚀 Quick Start

Prerequisites

pip install torch torchvision matplotlib numpy tqdm pillow tensorboard

Basic Usage

  1. Prepare Your Data: Place your images in a folder structure:
/CelebA
└── /img_align_celeba
    ├── image1.jpg
    ├── image2.jpg
    └── ...
  2. Configure Training: Modify the configuration dictionary:
config = {
    'data_path': './CelebA/img_align_celeba',
    'batch_size': 32,
    'img_size': 64,
    'latent_dim': 128,
    'hidden_dims': [32, 64, 128, 256],
    'beta': 4.0,  # Disentanglement strength
    'lr': 1e-4,
    'epochs': 50,
}
  3. Run Training: Execute the notebook cells sequentially or convert to a Python script.

  4. Monitor Progress: View training metrics with TensorBoard:

tensorboard --logdir=./runs
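
Inside the training loop, metrics are written with torch.utils.tensorboard's SummaryWriter. A minimal sketch of that pattern (the tag names below are illustrative, not necessarily the ones used in the notebook):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./runs')

def log_losses(writer, recon_loss, kl_loss, beta, step):
    # Log each term separately so the reconstruction/KL trade-off stays visible.
    writer.add_scalar('loss/reconstruction', recon_loss, step)
    writer.add_scalar('loss/kl_divergence', kl_loss, step)
    writer.add_scalar('loss/total', recon_loss + beta * kl_loss, step)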

πŸ—οΈ Architecture

Encoder

  • Convolutional layers with BatchNorm and LeakyReLU
  • Outputs mean (μ) and log-variance (log σ²) for the latent distribution
  • Default: 4 conv layers → 128D latent space
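
A minimal sketch of an encoder with this layout, including the reparameterization trick (layer sizes follow the defaults above; the notebook's exact implementation may differ):

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, in_channels=3, hidden_dims=(32, 64, 128, 256), latent_dim=128):
        super().__init__()
        layers = []
        for h in hidden_dims:
            # Each block halves the spatial resolution: 64 -> 32 -> 16 -> 8 -> 4
            layers += [nn.Conv2d(in_channels, h, kernel_size=3, stride=2, padding=1),
                       nn.BatchNorm2d(h),
                       nn.LeakyReLU()]
            in_channels = h
        self.conv = nn.Sequential(*layers)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dims[-1] * 4 * 4, latent_dim)

    def forward(self, x):
        h = self.conv(x).flatten(start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

def reparameterize(mu, logvar):
    # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar.
    std = torch.exp(0.5 * logvar)
    return mu + torch.randn_like(std) * std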

Decoder

  • Transposed convolutional layers
  • Reconstructs images from latent codes
  • Sigmoid activation for output normalization
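
A matching decoder sketch that mirrors the encoder (again an illustration, not necessarily the notebook's exact code):

import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self, out_channels=3, hidden_dims=(256, 128, 64, 32), latent_dim=128):
        super().__init__()
        self.fc = nn.Linear(latent_dim, hidden_dims[0] * 4 * 4)
        layers = []
        for i in range(len(hidden_dims) - 1):
            # Each block doubles the spatial resolution: 4 -> 8 -> 16 -> 32
            layers += [nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1],
                                          kernel_size=3, stride=2,
                                          padding=1, output_padding=1),
                       nn.BatchNorm2d(hidden_dims[i + 1]),
                       nn.LeakyReLU()]
        # Final upsampling step (32 -> 64) with sigmoid to map pixels into [0, 1].
        layers += [nn.ConvTranspose2d(hidden_dims[-1], out_channels,
                                      kernel_size=3, stride=2,
                                      padding=1, output_padding=1),
                   nn.Sigmoid()]
        self.deconv = nn.Sequential(*layers)

    def forward(self, z):
        h = self.fc(z).view(z.size(0), -1, 4, 4)
        return self.deconv(h)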

Loss Function

L = Reconstruction Loss + β × KL Divergence
  • Reconstruction Loss: MSE between input and output
  • KL Divergence: Regularization term encouraging a Gaussian latent distribution
  • β: Controls disentanglement (typical range: 1-10)
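
With a standard normal prior, the KL term has a closed form. A minimal sketch of this objective (the notebook's loss may differ in reduction details):

import torch
import torch.nn.functional as F

def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0):
    # Pixel-wise MSE, summed over pixels and averaged over the batch.
    recon_loss = F.mse_loss(x_recon, x, reduction='sum') / x.size(0)
    # Closed-form KL divergence between N(mu, sigma^2) and the N(0, I) prior.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)
    return recon_loss + beta * kl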

📊 Key Parameters

| Parameter | Description | Default | Tuning Tips |
| --- | --- | --- | --- |
| beta | Disentanglement strength | 4.0 | Higher → more disentangled but worse reconstruction |
| latent_dim | Latent space dimensions | 128 | More dims → more capacity but harder to interpret |
| hidden_dims | Encoder/decoder layer sizes | [32, 64, 128, 256] | Adjust based on image complexity |
| learning_rate | Optimizer learning rate | 1e-4 | Reduce if training is unstable |

🎨 Visualization Examples

1. Reconstruction Quality

Compare original images with their reconstructions to evaluate model performance.
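
A minimal plotting sketch (it assumes the model's forward pass returns (reconstruction, mu, logvar) and that images are normalized to [0, 1]; adapt to your signatures):

import torch
import matplotlib.pyplot as plt

@torch.no_grad()
def show_reconstructions(model, images, n=8):
    recon, _, _ = model(images[:n])
    fig, axes = plt.subplots(2, n, figsize=(2 * n, 4))
    for i in range(n):
        # Top row: originals; bottom row: reconstructions.
        axes[0, i].imshow(images[i].permute(1, 2, 0).cpu().numpy())
        axes[1, i].imshow(recon[i].permute(1, 2, 0).cpu().numpy())
        axes[0, i].axis('off')
        axes[1, i].axis('off')
    plt.show()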

2. Latent Space Traversal

Manipulate individual latent dimensions to discover learned features:

  • Dimension 5 might control lighting
  • Dimension 10 might control rotation
  • Dimension 15 might control expression
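
A traversal sketch (encode and decode are assumed method names; swap in whatever interface your model exposes):

import torch

@torch.no_grad()
def traverse_dimension(model, image, dim, values=(-3, -2, -1, 0, 1, 2, 3)):
    # Encode one image, then sweep a single latent coordinate
    # while holding all the others fixed.
    mu, _ = model.encode(image.unsqueeze(0))
    frames = []
    for v in values:
        z = mu.clone()
        z[0, dim] = v
        frames.append(model.decode(z).squeeze(0).cpu())
    return frames  # one decoded image per traversal value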

3. Interpolation

Smoothly transition between two images by interpolating in latent space.
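
A sketch of linear interpolation between two latent means (same encode/decode assumptions as above):

import torch

@torch.no_grad()
def interpolate(model, img_a, img_b, steps=10):
    mu_a, _ = model.encode(img_a.unsqueeze(0))
    mu_b, _ = model.encode(img_b.unsqueeze(0))
    outputs = []
    for t in torch.linspace(0, 1, steps):
        # Convex combination of the two latent codes.
        z = (1 - t) * mu_a + t * mu_b
        outputs.append(model.decode(z).squeeze(0).cpu())
    return outputs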

πŸ“ Project Structure

├── Disentanglement_Bvae.ipynb   # Main implementation notebook
├── CelebA/                       # Dataset directory
├── checkpoints/                  # Saved models
│   ├── best_model.pt
│   └── checkpoint_epoch_*.pt
├── runs/                         # TensorBoard logs
└── README.md                     # This file

🔧 Advanced Usage

Custom Dataset

import os
from PIL import Image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Collect every image file in the directory.
        self.paths = [os.path.join(root_dir, f) for f in sorted(os.listdir(root_dir))
                      if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Return the transformed image.
        img = Image.open(self.paths[idx]).convert('RGB')
        return self.transform(img) if self.transform else img

Loading Pre-trained Models

# map_location='cpu' lets the checkpoint load regardless of the training device
checkpoint = torch.load('checkpoints/best_model.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

Generating New Images

# Sample from prior distribution
num_samples = 16
samples = model.sample(num_samples, device)
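
Conceptually, sampling just draws latent codes from the N(0, I) prior and decodes them. A sketch of that logic (decode and latent_dim are assumptions about the model's interface):

import torch

@torch.no_grad()
def sample(model, num_samples, device, latent_dim=128):
    # Draw from the standard normal prior and map through the decoder.
    z = torch.randn(num_samples, latent_dim, device=device)
    return model.decode(z)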

📈 Training Tips

  1. Start with β=1: Train a standard VAE first, then gradually increase β (a simple schedule is sketched below)
  2. Monitor KL Divergence: Should stabilize after initial epochs
  3. Adjust Learning Rate: Use a ReduceLROnPlateau scheduler for adaptive learning (also sketched below)
  4. Checkpoint Regularly: Save models every 10 epochs
  5. Visualize Early: Check reconstructions after 5-10 epochs
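
Sketches for tips 1 and 3 (the warm-up length and scheduler settings here are illustrative defaults, not values from the notebook):

import torch

def beta_schedule(epoch, target_beta=4.0, warmup_epochs=10):
    # Ramp beta linearly from 1 (a standard VAE) up to its target value.
    return min(target_beta, 1.0 + (target_beta - 1.0) * epoch / warmup_epochs)

def make_scheduler(optimizer):
    # Halve the learning rate when validation loss plateaus;
    # call scheduler.step(val_loss) after each validation pass.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5)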

πŸ› Troubleshooting

Poor Reconstructions:

  • Decrease β value
  • Increase latent dimensions
  • Train for more epochs
  • Check learning rate

Not Disentangled:

  • Increase β gradually (4 → 6 → 8)
  • Ensure diverse training data
  • Increase model capacity
  • Train longer

Training Instability:

  • Reduce learning rate
  • Add gradient clipping (see the sketch below)
  • Check data normalization
  • Reduce batch size
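
Gradient clipping is a one-line addition to the backward pass; a minimal sketch:

import torch

def training_step(model, optimizer, loss, max_norm=1.0):
    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients whose global norm exceeds max_norm to damp exploding updates.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()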

📚 References

  • Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., Lerchner, A. "β-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework." ICLR 2017.
  • Kingma, D. P., Welling, M. "Auto-Encoding Variational Bayes." ICLR 2014.

πŸ“ Citation

If you use this implementation in your research, please cite:

@misc{bvae-implementation,
  author = {Vishva, MV},
  title = {β-VAE Implementation for Disentangled Representations},
  year = {2025},
  publisher = {GitHub},
  url = {https://github.com/Vishva2003/B-VAE_Implementation_for_Disentangled_Representations}
}

🤝 Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes:

  1. Fork the repository
  2. Create your feature branch (git checkout -b feature/AmazingFeature)
  3. Commit your changes (git commit -m 'Add some AmazingFeature')
  4. Push to the branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

πŸ™ Acknowledgments

  • Original Ξ²-VAE authors for the groundbreaking research
  • Community contributors and researchers in disentangled representation learning

📧 Contact

For questions or feedback, please open an issue on GitHub or contact dev.vishvamv@mail.com.


Happy Learning! 🚀 If you find this useful, please consider starring ⭐ the repository!
