
Research Documentation

Deep Learning-Based
ECG Arrhythmia Classification

Empirical validation of the Adam optimization algorithm (Kingma & Ba, 2015) applied to 5-class cardiac arrhythmia detection using a 1D Residual Neural Network trained on the MIT-BIH Arrhythmia Database.

PyTorch 2.x · MIT-BIH (109,446 beats) · 5-class classification · Captum Grad-CAM

Problem Statement

Cardiac arrhythmias affect over 1.5% of the global population and are a leading cause of sudden cardiac death. Manual interpretation of 24-hour Holter ECG recordings is time-consuming and subject to inter-observer variability. This project implements an end-to-end deep learning pipeline that classifies individual heartbeat segments into five clinically defined categories, with the specific research question:

"Do the Adam optimizer's theoretical advantages (adaptive per-parameter learning rates, bias-corrected moment estimates, and an O(√T) regret bound) translate to measurably superior convergence and final accuracy compared with SGD, Adagrad, and RMSProp on a severely class-imbalanced ECG classification task?"

Figure: Representative Normal ECG beat, 187 samples @ 360 Hz (~0.52 s). Annotated regions (sample indices): P-wave 10–30, QRS complex 50–80, T-wave 90–130.

Dataset

Primary dataset: MIT-BIH Arrhythmia Database (Moody & Mark, 2001) — 48 half-hour two-lead Holter recordings from 47 subjects, sampled at 360 Hz. Each heartbeat is segmented into a fixed 187-sample window centred on the R-peak. Secondary dataset: PTB Diagnostic ECG DB used in notebook 08 for binary generalisation testing (healthy vs. myocardial infarction).

Class | Label | Description | Count | % of total | Challenge
0 (N) | Normal | Normal sinus rhythm | 90,589 | 82.8% | Dominant class
1 (S) | Supraventricular | Ectopic, narrow QRS | 2,779 | 2.5% | Subtle P-wave changes
2 (V) | Ventricular | PVC, wide bizarre QRS | 7,236 | 6.6% | Morphologically distinct
3 (F) | Fusion | N + V hybrid beat | 803 | 0.7% | Rarest · hardest to separate
4 (Q) | Unknown / Paced | Paced or unclassifiable | 8,039 | 7.3% | High morphological variability

The extreme imbalance (Normal 83% vs. Fusion 0.7%) motivates two specific design choices: Focal Loss, which down-weights easy, well-classified examples during training, and SMOTE oversampling applied exclusively to the training split to prevent data leakage into validation.
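A minimal NumPy sketch of the SMOTE mechanism (Chawla et al., 2002) follows. It is not the project's implementation, which may well use a library such as imbalanced-learn; the function `smote` and its parameters are hypothetical. Each synthetic beat is placed at a random point on the line segment between a minority-class beat and one of its k nearest minority neighbours. This is why the call must come after the train/validation split: interpolating across splits would copy validation information into training.

```python
import numpy as np

def smote(X, y, minority_class, n_new, k=5, seed=0):
    """Hypothetical SMOTE sketch for one minority class (not the project's code).

    X: (N, 187) training beats, y: (N,) labels. Returns X and y with n_new
    synthetic beats of `minority_class` appended. Needs more than k minority beats.
    """
    rng = np.random.default_rng(seed)
    X_min = X[y == minority_class]
    # Squared Euclidean distances between minority beats via |a|^2 + |b|^2 - 2ab.
    sq = (X_min ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X_min @ X_min.T
    np.fill_diagonal(d2, np.inf)                 # a beat is not its own neighbour
    neighbours = np.argsort(d2, axis=1)[:, :k]   # k nearest minority neighbours per beat
    base = rng.integers(0, len(X_min), size=n_new)
    partner = neighbours[base, rng.integers(0, k, size=n_new)]
    gap = rng.random((n_new, 1))                 # interpolation factor in [0, 1)
    X_new = X_min[base] + gap * (X_min[partner] - X_min[base])
    y_new = np.full(n_new, minority_class, dtype=y.dtype)
    return np.vstack([X, X_new]), np.concatenate([y, y_new])
```

Usage follows the split: `X_train, y_train = smote(X_train, y_train, minority_class=3, n_new=...)`, with the validation split never passed in.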

Figure: Class distribution, MIT-BIH (109,446 beats)

Preprocessing Pipeline

All preprocessing is implemented in src/data/preprocessor.py. Leakage-sensitive steps run on the training split only; the test split receives nothing but per-sample Z-score normalisation, which uses no statistics from any other beat.

Raw CSV (187 values/row, integer amplitude)
↓ Z-score normalise per-sample: (x - μ) / σ — removes patient-level amplitude differences
↓ IQR outlier removal eliminates Holter digitisation artefacts (|z| > 3 IQR)
↓ Stratified 85/15 split preserves class ratios in train and val
↓ SMOTE (train only) synthetic minority oversampling — balances class counts
↓ Reshape → (N, 1, 187) PyTorch Conv1d format: (batch, channels, length)

Per-sample Z-score normalisation is particularly important for this task because ECG amplitude varies across patients and recording conditions. Population-level statistics computed over the full dataset before splitting would leak test-set information and can reduce generalisation. At inference time the same per-sample normalisation is applied, so no stored training statistics are required.
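Per-beat normalisation as described above can be sketched in NumPy as follows. The function name and the small `eps` guard for flat signals are assumptions, not details taken from src/data/preprocessor.py.

```python
import numpy as np

def zscore_per_beat(X, eps=1e-8):
    """Normalise each beat (row) by its own mean and standard deviation.

    X: array of shape (N, 187). No statistics are shared across beats, so the
    same function applies unchanged to train, test and live inputs.
    """
    X = np.asarray(X, dtype=np.float64)
    mu = X.mean(axis=1, keepdims=True)
    sigma = X.std(axis=1, keepdims=True)
    return (X - mu) / (sigma + eps)  # eps (assumed) avoids division by zero on flat rows
```

Indexing the result as `Z[:, None, :]` then yields the (N, 1, 187) layout that Conv1d expects.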

Model Architecture — 1D ResNet

The primary model (src/models/resnet1d.py) is a 1D Residual Network adapted from He et al. (2016) for time-series classification. Residual (skip) connections give gradients a direct path through the network, which mitigates vanishing gradients in deeper stacks regardless of the optimizer. Total parameters: ~790,000.

Input: (batch, 1, 187)
Stem: Conv1d(1→32, k=7, s=1) → BN → ReLU
layer1: ResidualBlock(32, s=1) × 2   (low-level: P-wave & QRS shape)
layer2: ResidualBlock(64, s=2) × 2   (mid-level: beat intervals, rhythm)
layer3: ResidualBlock(128, s=2) × 2   (high-level: arrhythmia class patterns)
Global Average Pooling (temporal dim)
Dense(128) → ReLU → Dropout(0.4)
Dense(5) → Softmax

Each ResidualBlock follows the pattern: Conv1d → BN → ReLU → Conv1d → BN → (+skip connection) → ReLU. Stride-2 blocks use a 1×1 projection shortcut to match dimensions. The Grad-CAM target layer is layer3[-1].conv2, exposed via model.get_gradcam_target_layer().
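A PyTorch sketch of the block pattern and layer stack described above. The 7-tap kernel inside the residual blocks, the padding, the bias-free convolutions and the `stem`/`head` attribute names are assumptions (the document gives a kernel size only for the stem), so the parameter count will not necessarily match ~790,000. The head returns logits rather than softmax probabilities, since Focal Loss, like cross-entropy, is numerically safer on log-softmax of raw logits; softmax can be applied at inference.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv1d -> BN -> ReLU -> Conv1d -> BN -> (+ shortcut) -> ReLU."""

    def __init__(self, in_ch, out_ch, stride=1, kernel_size=7):  # kernel size is an assumption
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(in_ch, out_ch, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn1 = nn.BatchNorm1d(out_ch)
        self.conv2 = nn.Conv1d(out_ch, out_ch, kernel_size, padding=pad, bias=False)
        self.bn2 = nn.BatchNorm1d(out_ch)
        self.relu = nn.ReLU()
        # 1x1 projection shortcut whenever the output shape differs from the input.
        self.shortcut = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm1d(out_ch),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + self.shortcut(x))


class ResNet1D(nn.Module):
    def __init__(self, n_classes=5):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv1d(1, 32, 7, stride=1, padding=3, bias=False), nn.BatchNorm1d(32), nn.ReLU()
        )
        self.layer1 = nn.Sequential(ResidualBlock(32, 32), ResidualBlock(32, 32))
        self.layer2 = nn.Sequential(ResidualBlock(32, 64, stride=2), ResidualBlock(64, 64))
        self.layer3 = nn.Sequential(ResidualBlock(64, 128, stride=2), ResidualBlock(128, 128))
        self.pool = nn.AdaptiveAvgPool1d(1)  # global average pooling over time
        self.head = nn.Sequential(
            nn.Flatten(), nn.Linear(128, 128), nn.ReLU(), nn.Dropout(0.4), nn.Linear(128, n_classes)
        )

    def forward(self, x):  # x: (batch, 1, 187)
        x = self.layer3(self.layer2(self.layer1(self.stem(x))))
        return self.head(self.pool(x))  # logits: (batch, n_classes)

    def get_gradcam_target_layer(self):
        return self.layer3[-1].conv2
```

Under these assumptions the two stride-2 stages shrink the 187-sample input to 47 time steps at layer3, which is the resolution Grad-CAM later upsamples from.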

Baseline for comparison: a 3-block 1D CNN (src/models/baseline_cnn.py) with no residual connections, used only in the optimizer comparison study to ensure architecture complexity does not confound optimizer performance differences.

Figure: Network architecture, visual overview

Training Setup

Adam Hyperparameters (paper defaults)

Parameter | Value | Source
Learning rate α | 0.001 | Algorithm 1
β₁ (1st moment) | 0.9 | Algorithm 1
β₂ (2nd moment) | 0.999 | Algorithm 1
ε (numerical stability) | 1e-8 | Algorithm 1
Weight decay | 0 | (none in Algorithm 1)
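The update rule behind these defaults (Algorithm 1 in Kingma & Ba, 2015) fits in a few lines of NumPy. This sketch is for illustration only, since training uses PyTorch's built-in optimizer; `adam_step` is a hypothetical helper, and t counts steps from 1.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update following Algorithm 1 of Kingma & Ba (2015); t starts at 1.

    Hypothetical helper for illustration; training itself uses torch.optim.Adam.
    """
    m = beta1 * m + (1 - beta1) * grad           # biased first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2      # biased second raw-moment estimate
    m_hat = m / (1 - beta1 ** t)                 # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)                 # bias-corrected second moment
    theta = theta - alpha * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# First step on gradients spanning six orders of magnitude.
grad = np.array([1e-4, 1.0, 100.0])
zeros = np.zeros(3)
theta, m, v = adam_step(zeros, grad, zeros, zeros, t=1)
```

On the first step every parameter moves by roughly α, whatever its gradient's scale. Without the two bias-correction lines the step would be 0.1/√0.001 ≈ 3.16 times larger, the kind of early-step inflation that §6.4 of the paper examines.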

Training Configuration

Setting | Value
Batch size | 128
Max epochs | 100
Early stopping | patience = 15 (val macro F1)
LR schedule | CosineAnnealingLR (T_max=100)
Focal Loss γ | 2.0
Focal Loss α | inverse class frequency
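The two tables above map onto PyTorch's `torch.optim.Adam` and `torch.optim.lr_scheduler.CosineAnnealingLR` roughly as follows. This is a sketch, not the project's Trainer: the `nn.Linear` model is a stand-in for ResNet1D, and computing α from the pre-SMOTE class counts and rescaling it to sum to 1 are assumptions, since the document only says "inverse class frequency" (after full SMOTE balancing, raw inverse frequencies would be uniform).

```python
import numpy as np
import torch
import torch.nn as nn

BATCH_SIZE, MAX_EPOCHS, PATIENCE = 128, 100, 15

# Pre-SMOTE class counts from the dataset table: N, S, V, F, Q.
counts = np.array([90589, 2779, 7236, 803, 8039], dtype=np.float64)
# Focal Loss alpha = inverse class frequency, rescaled to sum to 1 (assumed normalisation).
inv_freq = 1.0 / counts
alpha = torch.tensor(inv_freq / inv_freq.sum(), dtype=torch.float32)

model = nn.Linear(187, 5)  # stand-in for ResNet1D
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0
)
# One scheduler step per epoch; T_max matches the 100-epoch budget.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
```

With one `scheduler.step()` per epoch, the learning rate reaches half its initial value (0.0005) after 50 epochs.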

Focal Loss (Lin et al., 2017): FL(pₜ) = −αₜ · (1−pₜ)^γ · log(pₜ). The modulating factor (1−pₜ)^γ reduces the loss contribution of well-classified easy examples, forcing the model to focus on hard minority-class beats. With γ=2.0, a sample classified with 90% confidence contributes only 1% of the cross-entropy loss.
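A minimal PyTorch sketch of the multi-class Focal Loss formula above, assuming the model outputs raw logits. The class name `FocalLoss` and its arguments are illustrative, not the project's code. Setting γ = 0 and omitting α recovers ordinary cross-entropy, which makes the modulating factor easy to check in isolation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Multi-class Focal Loss: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).

    Illustrative sketch; expects raw logits of shape (B, C) and integer targets (B,).
    """

    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.gamma = gamma
        self.register_buffer("alpha", alpha)  # optional per-class weights, shape (C,)

    def forward(self, logits, target):
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)  # log p_t of the true class
        pt = log_pt.exp()
        loss = -((1.0 - pt) ** self.gamma) * log_pt               # modulated cross-entropy
        if self.alpha is not None:
            loss = self.alpha[target] * loss                       # alpha_t weighting
        return loss.mean()
```

For example, logits whose softmax assigns the true class probability 0.9 give a cross-entropy of about 0.105 and, with γ = 2 and no α, a focal loss of about 0.00105.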

Early stopping monitors validation macro F1 rather than accuracy, because accuracy is misleading under class imbalance — a model predicting only "Normal" achieves 83% accuracy while being clinically useless.
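The accuracy trap is easy to reproduce. The NumPy sketch below scores a degenerate model that predicts "Normal" for every beat, using the class counts from the dataset table; `macro_f1` is a hypothetical helper with the same intent as scikit-learn's `f1_score(..., average="macro")`.

```python
import numpy as np

def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN); 0 when undefined."""
    scores = []
    for c in range(n_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))

counts = [90589, 2779, 7236, 803, 8039]           # N, S, V, F, Q from the dataset table
y_true = np.repeat(np.arange(5), counts)
y_pred = np.zeros_like(y_true)                    # degenerate model: always "Normal"

accuracy = float(np.mean(y_true == y_pred))
print(f"accuracy = {accuracy:.3f}, macro F1 = {macro_f1(y_true, y_pred, 5):.3f}")
# accuracy = 0.828, macro F1 = 0.181
```

Only the Normal class earns a nonzero F1 (about 0.906), and averaging it with four zeros exposes the failure that accuracy hides.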

Figure: Focal Loss vs. cross-entropy, effect of the modulating factor (1−pₜ)^γ

At pₜ = 0.9 (confident, correct prediction): CE = −log(0.9) ≈ 0.105, while FL(γ=2) = (1 − 0.9)² · 0.105 ≈ 0.001 (taking αₜ = 1). This 99% reduction shifts the training signal toward hard examples.

Optimizer Comparison Study

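Five optimizers are trained through the identical Trainer loop (src/optimizers/comparison.py) on the same data split with the same random seed. The qualitative predictions drawn from the Adam paper (§4, §5, §6) are then tested empirically; the expected macro F1 values below are this project's own estimates, not figures reported in the paper.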
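One way to keep such a comparison controlled is to re-seed before building every model, so each optimizer starts from identical weights. The registry and the `make_run` helper below are hypothetical: only Adam's settings come from this document, and the learning rates and momentum for the other four optimizers are placeholder assumptions. The data order should be seeded as well, for example by passing a seeded `torch.Generator` to the `DataLoader`.

```python
import torch
import torch.nn as nn

# Hypothetical registry. Only Adam's settings come from this document; the
# learning rates and momentum for the other four optimizers are placeholders.
OPTIMIZERS = {
    "SGD": lambda params: torch.optim.SGD(params, lr=1e-2),
    "SGD+Momentum": lambda params: torch.optim.SGD(params, lr=1e-2, momentum=0.9),
    "Adagrad": lambda params: torch.optim.Adagrad(params, lr=1e-2),
    "RMSProp": lambda params: torch.optim.RMSprop(params, lr=1e-3),
    "Adam": lambda params: torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8),
}

def make_run(name, build_model, seed=42):
    """Re-seed, then build a fresh model and its optimizer (seed value is an assumption)."""
    torch.manual_seed(seed)  # identical initial weights for every optimizer
    model = build_model()
    return model, OPTIMIZERS[name](model.parameters())

for name in OPTIMIZERS:
    model, optimizer = make_run(name, lambda: nn.Linear(187, 5))  # stand-in for the real model
    # ...train with the shared Trainer loop and record validation macro F1...
```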

Optimizer | Adam paper prediction | Expected macro F1 | Why
SGD (vanilla) | Stalls on sparse minority gradients (§6.1) | ~72% | No adaptation; Fusion & SVE nearly invisible
SGD + Momentum | Faster, but no per-parameter adaptation (§5) | ~73% | Momentum helps convergence speed, not minority classes
Adagrad | Monotonically decaying LR stalls late (§5) | ~78% | LR → 0 prevents recovery from early suboptimal states
RMSProp | Can become unstable without bias correction (§6.4) | ~84% | Lacks bias correction → unstable early training
Adam | O(√T) regret · all advantages combined (§4) | ≥93% | Adaptive LR + bias correction + momentum

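The sparse gradient problem (§6.1) is particularly relevant here: in the raw (pre-SMOTE) distribution the Fusion class accounts for only about 7 of every 1,000 beats, i.e. fewer than one Fusion beat per 128-sample batch on average. Adam scales each parameter's update by its running second-moment estimate (bias-corrected, so the estimate is usable from the first steps), which is intended to keep the signal from these rare gradients from being drowned out, whereas vanilla SGD applies one global step size to every parameter.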
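The per-batch arithmetic behind this point, assuming batches are drawn uniformly at random from the raw (pre-SMOTE) class distribution with the batch size of 128 from the training configuration:

```python
p_fusion = 803 / 109446                      # Fusion share of all beats (~0.73%)
batch_size = 128

expected_per_batch = p_fusion * batch_size   # mean number of Fusion beats per batch
p_no_fusion = (1 - p_fusion) ** batch_size   # probability a batch contains no Fusion beat
print(f"{expected_per_batch:.2f} per batch, {p_no_fusion:.0%} of batches have none")
# 0.94 per batch, 39% of batches have none
```

After SMOTE balancing, by contrast, each of the five classes makes up roughly a fifth of every training batch.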

Figure: Expected macro F1 by optimizer, Adam paper predictions (§4, §5, §6.1, §6.4)

Grad-CAM Explainability

Gradient-weighted Class Activation Mapping is implemented via Captum's LayerGradCam targeting model.layer3[-1].conv2 — the final convolutional layer before global average pooling, where the highest-level feature representations are formed.

1. Forward pass: input (1, 1, 187) → predicted class score
2. Backprop gradients to layer3[-1].conv2 feature maps
3. Global-average gradients over temporal dim → channel weights αₖ
4. Weighted sum of activation maps → 1D CAM (length L')
5. ReLU → keep only positive contributions
6. Linear interpolation upsample → length 187
7. Normalise to [0, 1] · overlay on raw signal (RdYlBu_r colormap)
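Steps 3–7 can be written out in NumPy for a single beat, given the target layer's feature maps and their gradients from steps 1–2. This is an illustrative sketch, not Captum's code: Captum's `LayerGradCam` performs steps 1–5 itself (applying the ReLU only when called with `relu_attributions=True`), and `LayerAttribution.interpolate` can perform step 6. The endpoint-aligned linear interpolation used here is an assumption and may differ slightly from `torch.nn.functional.interpolate`'s default alignment.

```python
import numpy as np

def gradcam_1d(activations, gradients, out_len=187):
    """Grad-CAM steps 3-7 for one beat (illustrative sketch, not Captum's code).

    activations, gradients: arrays of shape (K, L') from the target layer, i.e.
    the feature maps and d(class score)/d(feature maps) from steps 1-2.
    """
    weights = gradients.mean(axis=1)                     # step 3: channel weights alpha_k
    cam = (weights[:, None] * activations).sum(axis=0)   # step 4: weighted sum -> (L',)
    cam = np.maximum(cam, 0.0)                           # step 5: keep positive evidence
    # Step 6: linear upsampling, aligning the first and last points (assumed alignment).
    src = np.linspace(0.0, 1.0, num=cam.shape[0])
    dst = np.linspace(0.0, 1.0, num=out_len)
    cam = np.interp(dst, src, cam)
    # Step 7: min-max normalise to [0, 1]; a constant map becomes all zeros.
    span = cam.max() - cam.min()
    return (cam - cam.min()) / span if span > 0 else np.zeros_like(cam)
```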

Clinical verification criteria:

Class S — Supraventricular

Heatmap should peak at the P-wave region (samples 10–30). Supraventricular arrhythmias originate above the ventricles; the diagnostic signal is the abnormal P-wave morphology preceding the (narrow) QRS complex.

Class V — Ventricular

Heatmap should peak at the QRS complex (samples 50–80). Ventricular ectopic beats are characterised by a wide, bizarre QRS produced by aberrant ventricular depolarisation, which is the defining morphological feature.
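The two criteria above can be turned into automatic checks on a normalised heatmap. The helpers below are hypothetical: `peaks_in_region` tests whether the Grad-CAM maximum lies in the expected window, and `region_mass` measures what fraction of the heatmap falls inside it, which is less sensitive to a single spurious peak. The windows are the sample ranges given above, treated as inclusive.

```python
import numpy as np

# Expected windows (sample indices, inclusive) from the clinical criteria above.
REGIONS = {"S": (10, 30), "V": (50, 80)}

def peaks_in_region(cam, cls):
    """Hypothetical check: does the heatmap maximum fall inside the expected window?"""
    start, end = REGIONS[cls]
    return start <= int(np.argmax(cam)) <= end

def region_mass(cam, cls):
    """Hypothetical check: fraction of total heatmap mass inside the expected window."""
    start, end = REGIONS[cls]
    total = float(cam.sum())
    return float(cam[start:end + 1].sum()) / total if total > 0 else 0.0
```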

Expected Results

Target performance on the held-out MIT-BIH test set (15% stratified split):

Class | Arrhythmia | Dataset % | Expected F1 | Notes
0 (N) | Normal | 82.8% | ~99% | Dominant class, easy to learn
1 (S) | Supraventricular | 2.5% | ~88% | Subtle P-wave changes
2 (V) | Ventricular | 6.6% | ~95% | Distinct morphology
3 (F) | Fusion Beat | 0.7% | ~82% | Rarest · SMOTE critical here
4 (Q) | Unknown / Paced | 7.3% | ~91% | Variable morphology
Overall: ≥98% accuracy · ≥93% macro F1 · Adam vs. SGD gap ~21 percentage points of macro F1

The large Adam vs. SGD gap in macro F1 (predicted at ~21 percentage points) was expected to be the primary empirical finding. Accuracy alone would not reveal such a gap: both SGD and Adam are expected to reach high accuracy by correctly classifying the dominant Normal class, while only Adam was predicted to learn the rare minority classes successfully.

Achieved Results

Final evaluation on the held-out MIT-BIH test set (15% stratified split, ~16,000 samples). The model was trained with Adam + CosineAnnealingLR for 100 epochs; the best checkpoint (epoch 98, val macro F1 = 0.9075) was evaluated once on the unseen test set.

Accuracy: 98.20% (target ≥ 98% ✓)
Macro F1: 91.40% (target ≥ 93%, missed by 1.6 percentage points)
Weighted F1: 98.23% (per-class F1 weighted by class support)
Mean ROC-AUC: 98.92% (one-vs-rest, 5 classes)
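The four headline metrics can be computed with scikit-learn as sketched below, assuming the model's softmax outputs are available as an (N, 5) array and that "mean ROC-AUC" denotes the unweighted mean of the five one-vs-rest AUCs, which is what `average="macro"` computes. `summary_metrics` is a hypothetical helper name.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

def summary_metrics(y_true, probs):
    """probs: (N, 5) softmax outputs whose rows sum to 1; y_true: (N,) labels 0-4."""
    y_pred = np.argmax(probs, axis=1)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "macro_f1": f1_score(y_true, y_pred, average="macro"),        # every class counts equally
        "weighted_f1": f1_score(y_true, y_pred, average="weighted"),  # weighted by class support
        # One-vs-rest AUC per class, then the unweighted mean over the 5 classes.
        "mean_roc_auc": roc_auc_score(y_true, probs, multi_class="ovr", average="macro"),
    }
```

Because weighted F1 is dominated by the Normal class's support, it stays close to accuracy; macro F1 is the figure that reflects minority-class performance.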

Figure: Per-class precision, recall and F1, test set

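Supraventricular (S) and Fusion (F) remain the hardest classes despite SMOTE augmentation: both have subtle or hybrid morphology, and together they make up only about 3.3% of the original (pre-SMOTE) data, so their synthetic samples are interpolated from a small pool of real beats. Normal (N) and Unknown/Paced (Q) achieve near-perfect F1, N because of its dominant frequency and Q because its waveform signatures are distinctive relative to the other four classes.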

Figure: Adam training curve, validation macro F1 over 100 epochs

Figure: Validation macro F1 convergence, all optimizers

Adam converges noisily — it explores broadly for 60 epochs before stabilising above 0.90. SGD+Momentum reaches its peak earlier (epoch 33) and is the best optimizer by val F1 (0.9217), contradicting the common assumption that Adam always wins on deep learning tasks. RMSProp is the most unstable, with variance spikes throughout training.

Figure: Best validation macro F1 achieved per optimizer (actual)

SGD+Momentum outperforms Adam on this dataset. A plausible explanation is the interaction between optimizer and schedule: CosineAnnealingLR with T_max=100 anneals the learning rate smoothly towards zero over the run, with no restarts, and SGD+Momentum may benefit more from this annealing than Adam running at its untuned paper-default α = 0.001. Confirming this would require ablating the schedule and tuning each optimizer's learning rate separately. The practical takeaway: adaptive optimizers are not universally superior; scheduling strategy and dataset size both mediate the outcome.
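For reference, CosineAnnealingLR follows the closed form η_t = η_min + ½(η_max − η_min)(1 + cos(π t / T_max)). With T_max = 100 over a 100-epoch run, the learning rate therefore falls monotonically along a single half-cosine and never restarts; periodic restarts would require a different scheduler such as `CosineAnnealingWarmRestarts`. The sketch assumes η_min = 0 (PyTorch's default) and a base rate of 0.001, Adam's setting above.

```python
import numpy as np

def cosine_lr(epoch, base_lr=1e-3, t_max=100, eta_min=0.0):
    """Closed form of CosineAnnealingLR: one half-cosine from base_lr down to eta_min."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + np.cos(np.pi * epoch / t_max))

lrs = np.array([cosine_lr(e) for e in range(101)])  # epochs 0..100
print(lrs[[0, 25, 50, 75, 100]])                     # steadily decreasing, ending at 0
```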

Figure: Confusion matrix, test set (ResNet1D + Adam)

Figure: ROC curves, one-vs-rest, all 5 classes

The confusion matrix confirms the model's primary failure mode: S beats misclassified as N. This is clinically meaningful — supraventricular ectopics (premature atrial contractions) can appear morphologically similar to normal sinus rhythm if the P-wave alteration is subtle. All 5 ROC curves cluster above 0.97 AUC, demonstrating that the model has strong discriminative power even for minority classes.

Cross-Dataset Generalisation — PTB Diagnostic ECG Database

Figure: PTB-DB ROC curve
PTB-DB F1: 99.64% (healthy vs. myocardial infarction)

The model was fine-tuned on the PTB Diagnostic ECG DB (binary classification: healthy vs. MI). Achieving 99.64% F1 on an entirely different dataset suggests that the deep residual features learnt on MIT-BIH (rhythm-based arrhythmias) transfer to morphology-based pathology detection, i.e. that the network captures fundamental ECG structure rather than dataset-specific artefacts. Because the network was fine-tuned on PTB-DB, this result measures how transferable the learnt features are, not zero-shot generalisation.

Notebook Pipeline

Each notebook runs on its own but consumes artefacts saved by the earlier ones, so run them in sequence from the project root.

01_eda.ipynb: Exploratory Data Analysis
Class distribution, waveform morphology, amplitude statistics. Saves reference_waveforms.npz, which the app uses.

02_preprocessing.ipynb: Data Pipeline
Z-score normalisation, IQR filtering, stratified split, SMOTE. Validates that no leakage occurs between splits.

03_baseline_cnn.ipynb: Baseline Model
3-layer CNN trained with Adam. Establishes a lower-bound performance benchmark for the optimizer comparison.

04_resnet_training.ipynb: Primary Model
Full ResNet1D training with Adam. Saves resnet1d_adam_best.pth, which the web app requires.

05_optimizer_comparison.ipynb: Optimizer Study
Trains SGD / SGD+M / Adagrad / RMSProp / Adam through the identical Trainer loop. Plots convergence curves and per-class F1.

06_evaluation.ipynb: Full Evaluation
Confusion matrix, per-class F1, ROC-AUC curves, and the full classification report on the held-out test set.

07_gradcam_explainability.ipynb: Explainability
1 correct + 1 misclassified Grad-CAM example per class (10 plots total). Clinical verification of attended ECG regions.

08_ptbdb_generalisation.ipynb: Generalisation
Tests the model on PTB Diagnostic ECG DB (binary: healthy vs. MI). Validates that features learnt on MIT-BIH transfer across datasets.

References

[1] Kingma, D. P., & Ba, J. (2015). Adam: A Method for Stochastic Optimization. ICLR 2015. arXiv:1412.6980.
[2] Moody, G. B., & Mark, R. G. (2001). The impact of the MIT-BIH Arrhythmia Database. IEEE Engineering in Medicine and Biology Magazine, 20(3), 45–50.
[3] Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal Loss for Dense Object Detection. IEEE ICCV. arXiv:1708.02002.
[4] He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. IEEE CVPR. arXiv:1512.03385.
[5] Chawla, N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P. (2002). SMOTE: Synthetic Minority Over-sampling Technique. JAIR, 16, 321–357.
[6] Selvaraju, R. R., et al. (2017). Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization. IEEE ICCV. arXiv:1610.02391.
[7] Kokhlikyan, N., et al. (2020). Captum: A unified and generic model interpretability library for PyTorch. arXiv:2009.07896.