Skip to content

Latest commit

 

History

History
285 lines (215 loc) · 10.9 KB

File metadata and controls

285 lines (215 loc) · 10.9 KB

Topological Autoencoder with COAST Loss — Experiment Results

1. Dataset

Item Value
Source GEO Accession GSE62944 (TCGA RNA-seq, Rsubread TPM)
Cancer Type TCGA-BRCA (Breast Invasive Carcinoma)
Total Samples 1,215
Tumor Samples 1,102 (90.7%)
Normal Samples 113 (9.3%)
Original Genes ~23,368
Genes After Filtering 20,862
Bad Genes Removed 2,068 (
Batch Sources (TSS) 31 BRCA-specific Tissue Source Sites
Expression Unit TPM (Transcripts Per Million)
BRCA TSS Codes A1, A2, A7, A8, AC, AN, AO, AQ, AR, B6, BH, C8, D8, E2, E9, EW, GM, GI, HN, LD, LL, MS, OL, PE, PL, S3, UL, UU, WT, XX, Z7

Preprocessing Pipeline

Raw TPM (tumor + normal)
  → Merge & BRCA patient filter (by TSS code)
  → Bad gene removal (overall_bad_genes_4fold.txt, 2,068 genes)
  → Selective log1p (genes with |skew|>2.0 or kurtosis>10.0)
  → Zero-variance gene removal
  → GPU ComBat batch correction (TSS as batch, Target as covariate)
  → Output: 1,215 samples × 20,862 genes

2. Model Architecture

Topological Autoencoder (TAE)

Encoder:  20,862 → 1024 → 256 → z (latent_dim)
          [BN+LeakyReLU(0.2)] [BN+LeakyReLU(0.2)] [Linear, no activation]

Decoder:  z → 256 → 1024 → 20,862
          [BN+LeakyReLU(0.2)] [BN+LeakyReLU(0.2)] [ReLU (non-negative output)]
Layer Encoder Decoder
Hidden 1 Linear(20862, 1024) + BN + LeakyReLU(0.2) Linear(latent_dim, 256) + BN + LeakyReLU(0.2)
Hidden 2 Linear(1024, 256) + BN + LeakyReLU(0.2) Linear(256, 1024) + BN + LeakyReLU(0.2)
Output Linear(256, latent_dim) Linear(1024, 20862) + ReLU
  • BatchNorm: stabilizes training for high-dimensional tabular data
  • LeakyReLU(0.2): prevents dead neurons
  • Encoder output: no activation (full real-valued latent space)
  • Decoder output: ReLU (TPM values are non-negative)

3. Model Configurations

3.1 Distance Metrics for Topological Loss

Metric Formula Property
Euclidean D_ij = √Σ(x_ik - x_jk)² Absolute distance; suffers from curse of dimensionality
Cosine D_ij = 1 - (x·y)/(‖x‖·‖y‖) Angular divergence; magnitude-invariant
Pearson D_ij = 1 - r_ij (Pearson correlation) Co-expression pattern; shift+scale invariant

3.2 Loss Functions

Baseline — MSE Topological Loss:

L_total = L_recon + λ · L_topo
L_recon = MSE(x, x̂),   L_topo = MSE(D̃_latent, D̃_original)

where D̃ = D / max(D) (min-max normalized pairwise distance matrix).

Proposed — COAST (Cosine Optimized Adaptive Sinkhorn Transport) Loss:

  1. Convert distances to probability distributions via Gaussian RBF kernel:
P_ij = exp(-D²_ij / 2σ²) / Σ_k exp(-D²_ik / 2σ²)
  1. Compute entropy-regularized OT via Sinkhorn iterations (log-domain, 50 iters):
OT_ε(a, b) = min_π∈Π(a,b) ⟨C, π⟩ + ε · KL(π ‖ a⊗b)
  1. Debiased Sinkhorn divergence:
S(a, b) = OT(a, b) - ½·OT(a, a) - ½·OT(b, b) ≥ 0
  1. COAST loss with topo multiplier m and adaptive weighting (Kendall et al., 2018):
L_COAST = ½·e^(-s_r)·L_recon + ½·e^(-s_t)·m·L_topo + ½·s_r + ½·s_t

3.3 Adaptive Loss Weighting (Kendall et al., 2018)

Learns task-specific uncertainty s = log(σ²) for each loss term:

L = ½·e^(-s_r)·L_recon + ½·e^(-s_t)·L_topo + ½·s_r + ½·s_t
  • s_r, s_t: nn.Parameter, trained jointly with model via Adam
  • Automatically balances reconstruction vs. topology preservation
  • Used in all COAST experiments; baseline uses fixed λ = 1.0

3.4 Implementation Details

  • Envelope Theorem: Detach Sinkhorn dual variables (f, g) after convergence; backprop only through final cost (~50x VRAM reduction)
  • Log-domain Sinkhorn: All iterations use logsumexp to prevent numerical underflow
  • Learnable bandwidth σ: initialized to median pairwise distance, trained jointly

4. Training Configuration

Parameter Value
Optimizer Adam
Learning Rate 1e-4
Weight Decay 1e-5
Batch Size 64
Epochs 100
Validation Split 20% (stratified by class)
Topo Weight (baseline) 1.0
Classifier Probe LogisticRegression (every 10 epochs)
Best Model Selection Minimum validation total loss
GPU NVIDIA RTX 4060 (8GB VRAM)

Experimental Grid

Factor Values
Latent Dimension 16, 32, 64
Distance Metric Euclidean, Cosine, Pearson
Loss Type Standard MSE, COAST
Topo Multiplier (COAST) 3 values per config (log-scale random sweep)

Total: 27 configurations (9 baseline + 18 COAST)


5. Main Results

5.1 Full Results Table (sorted by Accuracy)

Rank Dim Distance Loss Topo Mult. Acc. AUC F1 Prec. Recall
1 16 Cosine COAST 45,708.8 98.77% 0.988 0.993 0.991 0.996
2 32 Cosine COAST 1,023.3 97.12% 0.989 0.984 0.986 0.982
3 16 Cosine Std MSE 95.47% 0.980 0.975 0.960 0.991
4 32 Pearson COAST 186.2 95.06% 0.986 0.973 0.964 0.982
5 32 Cosine COAST 26.3 94.65% 0.972 0.971 0.952 0.991
5 64 Cosine COAST 31,622.8 94.65% 0.972 0.971 0.960 0.982
7 64 Pearson COAST 25,704.0 94.24% 0.981 0.969 0.960 0.977
8 16 Euclidean Std MSE 93.83% 0.968 0.967 0.944 0.991
8 32 Pearson COAST 58,884.4 93.83% 0.980 0.966 0.956 0.977
8 16 Pearson COAST 97.7 93.83% 0.949 0.967 0.944 0.991
11 16 Cosine COAST 31.6 91.77% 0.968 0.956 0.920 0.996
11 64 Cosine COAST 1,621.8 91.77% 0.966 0.957 0.917 1.000
11 64 Cosine COAST 6,606.9 91.77% 0.965 0.957 0.917 1.000
14 64 Euclidean Std MSE 91.36% 0.976 0.954 0.920 0.991
14 64 Pearson COAST 11.0 91.36% 0.924 0.954 0.913 1.000
16 16 Pearson Std MSE 90.53%* 0.863 0.950 0.905 1.000
16 16 Pearson COAST 6,918.3 90.53%* 0.887 0.950 0.905 1.000
16 16 Pearson COAST 13,182.6 90.53%* 0.809 0.950 0.905 1.000
16 64 Pearson Std MSE 90.53%* 0.972 0.950 0.905 1.000
16 64 Pearson COAST 55.0 90.53%* 0.932 0.950 0.905 1.000
16 32 Cosine Std MSE 90.53%* 0.885 0.950 0.905 1.000
16 64 Cosine Std MSE 90.53%* 0.909 0.950 0.905 1.000
16 32 Cosine COAST 3,801.9 90.53%* 0.877 0.950 0.905 1.000
24 16 Cosine COAST 186.2 90.12% 0.934 0.947 0.926 0.968
25 32 Euclidean Std MSE 89.71% 0.957 0.945 0.911 0.982
25 32 Pearson Std MSE 89.71% 0.959 0.945 0.911 0.982
27 32 Pearson COAST 12,589.3 89.30% 0.924 0.943 0.911 0.977

* Degenerate classifier (predicts all samples as Tumor; Precision = class prior = 90.53%, Recall = 100%)

5.2 Baseline Comparison (Euclidean + Standard MSE)

Dim Accuracy AUC F1 Precision Recall
16 93.83% 0.968 0.967 0.944 0.991
32 89.71% 0.957 0.945 0.911 0.982
64 91.36% 0.976 0.954 0.920 0.991

Best baseline: dim16, 93.83%

5.3 Effect of Distance Metric (Standard MSE, no Sinkhorn)

Distance Dim 16 Dim 32 Dim 64
Euclidean 93.83% 89.71% 91.36%
Cosine 95.47% 90.53%* 90.53%*
Pearson 90.53%* 89.71% 90.53%*

Cosine distance at dim16 improves over Euclidean by +1.64%p without any other changes.

5.4 Effect of COAST Loss (Cosine distance, best multiplier per dim)

Dim Cosine + MSE Cosine + COAST Improvement
16 95.47% 98.77% (m=45,709) +3.30%p
32 90.53%* 97.12% (m=1,023) +6.59%p
64 90.53%* 94.65% (m=31,623) +4.12%p

COAST loss recovers degenerate classifiers at dim32/64 and pushes dim16 to 98.77%.

5.5 Best Model vs. Baseline

Baseline (Euc+MSE, dim16) Best (COAST, dim16) Delta
Accuracy 93.83% 98.77% +4.94%p
AUC 0.968 0.988 +0.020
F1 0.967 0.993 +0.026
Precision 0.944 0.991 +0.047
Recall 0.991 0.996 +0.005
Compression 20,862 → 16 (1,304:1) 20,862 → 16 (1,304:1) Same

6. COAST Model Details (Best Configuration)

Parameter Value
Latent Dimension 16
Distance Metric Cosine
Topo Multiplier 45,708.8
Best Epoch 99 / 100
Best Val Loss (total) 3,985.22
Best Val Recon Loss 9,232.30
Best Val Topo Loss 4.638
Learned w_recon 0.4308
Learned w_topo 0.4397
Learned log_var_recon 0.1489
Learned log_var_topo 0.1285
Learned sigma (bandwidth) 0.6935

Adaptive Weight Patterns Across All COAST Runs

Pattern w_recon w_topo Condition
Balanced ~0.431 ~0.44 High topo_multiplier (>1000)
Topo-heavy ~0.431 ~0.59 Low topo_multiplier (<200)

The network compensates for low multiplier by increasing w_topo, but there is a ceiling (~0.59) beyond which further compensation is insufficient.


7. Degenerate Classifier Analysis

11 of 27 configurations (40.7%) collapsed to degenerate classifiers that predict all samples as Tumor:

  • Accuracy = 90.53% (= Tumor class proportion)
  • Precision = 90.53%, Recall = 100%

Conditions that cause degeneration:

  • Pearson + Std MSE: 2/3 dims (all except dim16 which also marginally failed with AUC=0.863)
  • Cosine + Std MSE: 2/3 dims (dim32, dim64)
  • COAST with poor topo_multiplier: several cases

Conditions that prevent degeneration:

  • Euclidean + Std MSE: all 3 dims (stable but lower accuracy)
  • COAST with multiplier in ~1,000–50,000 range: consistently high performance

8. Key Findings

  1. Cosine > Euclidean > Pearson for gene expression topological loss (at dim16, standard MSE)
  2. COAST loss definitively outperforms MSE distance matching across all dimensions
  3. dim16 + COAST achieves best performance: 98.77% accuracy (+4.94%p over baseline)
  4. Topo multiplier is critical: sweet spot ~1,000–50,000 for cosine distance
  5. Adaptive weighting successfully balances loss terms (w_recon ~ w_topo at convergence)
  6. COAST recovers degenerate classifiers that fail under standard MSE
  7. Simple logistic regression on 16D latent vectors achieves near-perfect classification