feat: added cnn_lstm#18

Merged
gaurav12301010 merged 1 commit into OPCODE-Open-Spring-Fest:main from adityacosmos24:main
Nov 15, 2025
Conversation

@adityacosmos24
Contributor

Add CNN-LSTM Hybrid Model for ECG Arrhythmia Classification

Summary

This PR introduces a new hybrid deep learning model combining 2D Convolutional Neural Networks (CNN) with Bidirectional LSTM for ECG arrhythmia classification. The model achieves 98.91% test accuracy on the MIT-BIH Arrhythmia Database, providing an alternative architecture to the existing LSTM-only implementation.

What's New

Model Architecture

  • Hybrid CNN-LSTM Architecture: Combines 2D CNN layers for feature extraction with bidirectional LSTM layers for temporal sequence modeling
  • Architecture Details:
    • 2D CNN Feature Extractor: Two convolutional blocks with BatchNorm and MaxPooling
    • Bidirectional LSTM: 2-layer bidirectional LSTM (64 hidden units each direction)
    • Classification Head: Fully connected layers with dropout regularization

Key Features

  • ✅ Complete end-to-end pipeline from data download to model evaluation
  • ✅ Automated data preprocessing with bandpass filtering (0.5-40 Hz)
  • ✅ Beat segmentation around R-peaks (250 samples per beat)
  • ✅ Z-score normalization for each beat segment
  • ✅ Comprehensive evaluation metrics and visualizations
  • ✅ Confusion matrices (raw and normalized)
  • ✅ Sample prediction visualizations with confidence scores

Model Performance

Test Results

  • Test Accuracy: 98.91%
  • Weighted Precision: 98.90%
  • Weighted Recall: 98.91%
  • Weighted F1-Score: 98.90%

Per-Class Performance

Class                 Precision  Recall   F1-Score  Support
F (Fusion)            82.93%     80.00%   81.44%    85
N (Normal)            99.31%     99.58%   99.44%    9,054
Q (Unknown)           98.31%     98.31%   98.31%    1,122
S (Supraventricular)  95.49%     92.28%   93.86%    298
V (Ventricular)       97.96%     96.28%   97.11%    698
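
As a quick consistency check, the weighted figures under Test Results can be recomputed from the per-class table as support-weighted averages. The sketch below only re-uses the numbers reported in this PR; the dictionary layout and the `weighted` helper are illustrative names, not part of the notebook.

```python
# (precision %, recall %, F1 %, support) per class, copied from the PR's table
per_class = {
    "F": (82.93, 80.00, 81.44, 85),
    "N": (99.31, 99.58, 99.44, 9054),
    "Q": (98.31, 98.31, 98.31, 1122),
    "S": (95.49, 92.28, 93.86, 298),
    "V": (97.96, 96.28, 97.11, 698),
}

# Total number of test beats across all classes
total_support = sum(row[3] for row in per_class.values())

def weighted(metric_index):
    """Support-weighted average of one metric column (0=precision, 1=recall, 2=F1)."""
    return sum(row[metric_index] * row[3] for row in per_class.values()) / total_support

print(f"test beats: {total_support}")
print(f"weighted precision: {weighted(0):.2f}%")
print(f"weighted recall:    {weighted(1):.2f}%")
print(f"weighted F1:        {weighted(2):.2f}%")
```

The weighted averages are dominated by the N class (about 80% of the test beats), which is why they sit near 98.9% even though the F class scores closer to 81%.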

Dataset

  • Source: MIT-BIH Arrhythmia Database (48 records)
  • Total Samples: 112,559 beat segments
  • Class Distribution:
    • N (Normal): 80.48%
    • Q (Unknown): 9.91%
    • V (Ventricular): 6.43%
    • S (Supraventricular): 2.47%
    • F (Fusion): 0.71%

Technical Details

Data Preprocessing

  • Bandpass filtering (0.5-40 Hz, 4th order Butterworth)
  • Beat segmentation: 100 samples before + 150 samples after R-peak
  • Z-score normalization per segment
  • Label grouping into 5 main categories (N, S, V, F, Q)
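
The filtering, segmentation, and normalization steps above could look roughly like the following sketch. It assumes the 360 Hz sampling rate of MIT-BIH records and zero-phase filtering with `filtfilt`; the PR does not say which filtering call the notebook uses, and the function names here are hypothetical.

```python
import numpy as np
from scipy.signal import butter, filtfilt

FS = 360              # MIT-BIH sampling rate in Hz (assumption: not stated in the PR)
PRE, POST = 100, 150  # samples before / after each R-peak, as listed in the PR

def bandpass(signal, low=0.5, high=40.0, fs=FS, order=4):
    """Butterworth bandpass, applied forward and backward (zero phase)."""
    b, a = butter(order, [low, high], btype="bandpass", fs=fs)
    return filtfilt(b, a, signal)

def segment_beats(signal, r_peaks, pre=PRE, post=POST):
    """Cut a (pre + post)-sample window around each R-peak and z-score it.

    Peaks too close to either end of the record are skipped.
    Returns the beats and the indices of the R-peaks that were kept.
    """
    beats, kept = [], []
    for i, r in enumerate(r_peaks):
        if r - pre < 0 or r + post > len(signal):
            continue
        seg = signal[r - pre : r + post]
        seg = (seg - seg.mean()) / (seg.std() + 1e-8)  # per-segment z-score
        beats.append(seg)
        kept.append(i)
    return np.asarray(beats), np.asarray(kept)

# Usage on a synthetic 10-second trace (stand-in for a real record)
rng = np.random.default_rng(0)
raw = np.sin(2 * np.pi * 1.2 * np.arange(10 * FS) / FS) + 0.1 * rng.standard_normal(10 * FS)
beats, kept = segment_beats(bandpass(raw), r_peaks=[50, 500, 1000, 3590])
print(beats.shape)  # two peaks are too close to the edges and get dropped
```

The returned `kept` indices make it possible to keep beat labels aligned with the segments that survive the edge check.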

Training Configuration

  • Framework: PyTorch
  • Optimizer: Adam (lr=0.001)
  • Loss Function: CrossEntropyLoss
  • Batch Size: 128
  • Epochs: 50
  • Train/Val/Test Split: 80%/10%/10%
  • Device: CUDA (if available) or CPU
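
A minimal training loop matching this configuration might look like the sketch below. The random (non-stratified) split, the fixed seed, and the `train` function itself are assumptions for illustration; the notebook's actual loop may differ.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split

def train(model, dataset, epochs=50, batch_size=128, lr=1e-3, seed=0):
    """Train with Adam + CrossEntropyLoss on an 80/10/10 random split.

    Returns the per-epoch validation accuracy and the held-out test subset.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n = len(dataset)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    train_set, val_set, test_set = random_split(
        dataset,
        [n_train, n_val, n - n_train - n_val],
        generator=torch.Generator().manual_seed(seed),
    )
    train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_dl = DataLoader(val_set, batch_size=batch_size)

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    val_accuracy = []

    for _ in range(epochs):
        model.train()
        for x, y in train_dl:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()

        # Validation pass without gradient tracking
        model.eval()
        correct = seen = 0
        with torch.no_grad():
            for x, y in val_dl:
                x, y = x.to(device), y.to(device)
                correct += (model(x).argmax(dim=1) == y).sum().item()
                seen += y.numel()
        val_accuracy.append(correct / seen)

    return val_accuracy, test_set

# Usage with random data and a tiny stand-in model (not the PR's network)
data = TensorDataset(torch.randn(200, 250, 1), torch.randint(0, 5, (200,)))
stand_in = nn.Sequential(nn.Flatten(), nn.Linear(250, 5))
history, test_set = train(stand_in, data, epochs=2)
```

With classes as imbalanced as in this dataset, a stratified split would keep the rare F class represented in all three subsets; the random split here is only the simplest option.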

Model Architecture Breakdown

Input: (batch, 250, 1)
  ↓
2D CNN Block 1: Conv2d(1→32) + BatchNorm + ReLU + MaxPool
  ↓
2D CNN Block 2: Conv2d(32→64) + BatchNorm + ReLU + MaxPool
  ↓
Reshape: (batch, 62, 64)
  ↓
Bidirectional LSTM: 2 layers, 64 hidden units each direction
  ↓
FC Layers: 128 → 64 → 5 (with dropout 0.3)
  ↓
Output: 5-class classification
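
The diagram above can be expressed as a PyTorch module along these lines. This is a sketch that reproduces the listed tensor shapes, not the notebook's code: the (5, 1) kernel, the (2, 1) pooling window, the use of the last LSTM time step, and the class name `CNNLSTM` are all assumptions.

```python
import torch
import torch.nn as nn

class CNNLSTM(nn.Module):
    """Hypothetical reconstruction of the CNN + BiLSTM classifier."""

    def __init__(self, n_classes=5, hidden=64, dropout=0.3):
        super().__init__()

        def conv_block(c_in, c_out):
            # Kernel and pooling sizes are assumed; they act along time only
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=(5, 1), padding=(2, 0)),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(2, 1)),
            )

        self.cnn = nn.Sequential(conv_block(1, 32), conv_block(32, 64))
        self.lstm = nn.LSTM(
            input_size=64, hidden_size=hidden, num_layers=2,
            batch_first=True, bidirectional=True,
        )
        self.head = nn.Sequential(
            nn.Linear(2 * hidden, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):                    # x: (batch, 250, 1)
        x = x.unsqueeze(1)                   # (batch, 1, 250, 1)
        x = self.cnn(x)                      # (batch, 64, 62, 1)
        x = x.squeeze(-1).permute(0, 2, 1)   # (batch, 62, 64)
        out, _ = self.lstm(x)                # (batch, 62, 128)
        return self.head(out[:, -1])         # (batch, 5) class logits

# Shape check on a random batch
model = CNNLSTM().eval()
with torch.no_grad():
    logits = model(torch.randn(4, 250, 1))
print(logits.shape)
```

Two pooling steps take the 250-sample beat to 125 and then 62 time steps, which matches the (batch, 62, 64) reshape in the diagram, and the two 64-unit LSTM directions concatenate to the 128 features fed to the first FC layer.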

Files Added

  • cnn_lstm.ipynb: Complete notebook with data processing, model training, and evaluation

Generated Outputs

  • data/ecg_mitdb_processed.npz: Preprocessed dataset
  • sample_beats.png: Visualization of sample beats for each class
  • confusion_matrices.png: Raw and normalized confusion matrices
  • sample_predictions.png: Random sample predictions with confidence scores

Comparison with Existing LSTM Model

  • Architecture: Hybrid CNN-LSTM vs. Stacked LSTM
  • Framework: PyTorch vs. TensorFlow/Keras
  • Feature Extraction: CNN-based vs. LSTM-only
  • Performance: Comparable accuracy (~98.9% vs. ~98%)

Testing

  • Model trained and evaluated on MIT-BIH Arrhythmia Database
  • All visualizations generated successfully
  • Classification metrics computed and verified

Future Improvements

  • Add model checkpoint saving/loading
  • Implement early stopping
  • Add hyperparameter tuning capabilities
  • Include class imbalance handling (SMOTE, class weights)
  • Add model interpretability (attention visualization, Grad-CAM)
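
For the class-weight option, one common heuristic is inverse-frequency weighting, where each class gets weight n_samples / (n_classes * class_count). The sketch below applies it to the class distribution reported under Dataset; it is one possible choice, not something the PR implements.

```python
# Class shares in percent, copied from the Dataset section of this PR
class_share = {"N": 80.48, "Q": 9.91, "V": 6.43, "S": 2.47, "F": 0.71}

n_classes = len(class_share)

# Inverse-frequency ("balanced") weights: 1 / (n_classes * class fraction)
class_weights = {
    label: 100.0 / (n_classes * share) for label, share in class_share.items()
}

for label, weight in sorted(class_weights.items(), key=lambda kv: kv[1]):
    print(f"{label}: {weight:.2f}")
```

Weights like these could be passed as a tensor to the `weight` argument of `torch.nn.CrossEntropyLoss`, ordered to match the model's class indices. The rare F class ends up weighted roughly a hundred times more heavily than N, so in practice the weights are often clipped or softened.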

Dependencies

  • torch (PyTorch)
  • numpy
  • matplotlib
  • scipy
  • scikit-learn
  • wfdb
  • tqdm
  • seaborn

Note: This model provides an alternative approach to ECG classification using a hybrid CNN-LSTM architecture, complementing the existing LSTM-only implementation in the repository.

@gaurav12301010 gaurav12301010 merged commit e09e15d into OPCODE-Open-Spring-Fest:main Nov 15, 2025
5 of 6 checks passed