This project implements a deep learning framework for predicting the binding affinity between drug molecules (ligands) and target proteins. The model leverages a hybrid architecture combining Graph Neural Networks (GNN) for molecular representation learning and Evolutionary Scale Modeling (ESM) for protein sequence embeddings. This approach allows for a comprehensive understanding of both the structural properties of small molecules and the biological context of protein targets.
- Graph Neural Network (GNN) for Ligands: Utilizes TransformerConv layers from PyTorch Geometric to process molecular graphs, capturing complex atomic interactions and structural features.
- ESM Protein Embeddings: Integrates pre-trained ESM-2 (esm2_t6_8M_UR50D) models to generate high-quality embeddings for protein sequences, ensuring robust representation of biological targets.
- Hybrid Architecture: Concatenates ligand and protein representations to predict binding affinity values through a multi-layer perceptron (MLP) regressor.
- Automated Preprocessing Pipeline: Includes scripts for validating SMILES strings, generating protein embeddings, and preparing datasets for training.
- Comprehensive Evaluation: Reports standardized evaluation metrics, including Mean Squared Error (MSE), Root Mean Squared Error (RMSE), and R-squared (R²).
DTBA/
├── data/
│ ├── processed/ # Processed PyTorch Geometric data files
│ └── raw/ # Raw input data (e.g., Ki_bind.tsv)
├── preprocessing/
│ ├── dataset.py # PyTorch Geometric dataset implementation
│ └── preprocess_drugs.py # Data cleaning and embedding generation script
├── results/ # Training visualizations and loss analysis
├── saved_models/ # Model checkpoints and best performing models
├── evaluate.py # Script for model evaluation
├── model.py # Neural network architecture definition
├── train.py # Main training loop and optimization
└── requirements.txt # Project dependencies
- Python 3.8+
- CUDA-enabled GPU (recommended for training)
- Clone the repository:

      git clone <repository-url>
      cd DTBA

- Create a virtual environment (optional but recommended):

      python -m venv venv
      source venv/bin/activate  # On Windows use `venv\Scripts\activate`

- Install dependencies:

      pip install -r requirements.txt

Note: You may need to install PyTorch and PyTorch Geometric specifically for your CUDA version. Please refer to the PyTorch website for specific instructions.
Before training, the raw data must be processed to generate protein embeddings and validate molecular structures.

    python preprocessing/preprocess_drugs.py

This script reads from data/raw/Ki_bind.tsv, filters invalid entries, generates ESM embeddings for proteins, and saves the train/test splits to data/raw/.
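The "filters invalid entries" step can be illustrated with RDKit, which returns `None` from `Chem.MolFromSmiles` when a SMILES string cannot be parsed. This is a minimal sketch, not the code in preprocess_drugs.py: the helper names, the `smiles` and `affinity` keys, and the sample rows are made up for illustration, and it assumes RDKit is installed.

```python
from rdkit import Chem


def is_valid_smiles(smiles) -> bool:
    """Return True if RDKit can parse the string into a molecule."""
    # Empty strings parse to an empty molecule rather than None,
    # so they are rejected explicitly.
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    return Chem.MolFromSmiles(smiles) is not None


def filter_valid(rows):
    """Keep only rows whose (hypothetical) 'smiles' field is parseable."""
    return [row for row in rows if is_valid_smiles(row["smiles"])]


# Made-up rows for illustration only; not taken from Ki_bind.tsv.
rows = [
    {"smiles": "CCO", "affinity": 1.0},   # ethanol, valid
    {"smiles": "C1CC", "affinity": 2.0},  # unclosed ring, invalid
    {"smiles": "", "affinity": 3.0},      # empty, rejected
]
clean = filter_valid(rows)
```

RDKit logs a parse error to stderr for each rejected string; that output is expected and can be ignored here.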
To train the model, run the training script. This will initialize the model, load the dataset, and begin the training process.

    python train.py

The script will:
- Load the processed dataset.
- Train the model for the specified number of epochs.
- Save the best model to saved_models/best_model.pth.
- Generate loss analysis plots in the results/ directory.
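The "save the best model" step is commonly implemented by tracking the lowest validation loss seen so far and writing a checkpoint whenever it improves. The sketch below shows that pattern on synthetic data. It is not the code in train.py: the stand-in linear model, the random tensors, the learning rate, and the use of validation loss as the selection criterion are all assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model and synthetic data; the real script trains MainNetwork.
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
x_train, y_train = torch.randn(32, 4), torch.randn(32, 1)
x_val, y_val = torch.randn(8, 4), torch.randn(8, 1)

# A temporary directory stands in for saved_models/.
ckpt_path = os.path.join(tempfile.mkdtemp(), "best_model.pth")
best_val = float("inf")
history = []  # (train_loss, val_loss) per epoch, usable for loss plots

for epoch in range(5):
    model.train()
    optimizer.zero_grad()
    train_loss = loss_fn(model(x_train), y_train)
    train_loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()
    history.append((train_loss.item(), val_loss))

    # Checkpoint only when the validation loss improves.
    if val_loss < best_val:
        best_val = val_loss
        torch.save(model.state_dict(), ckpt_path)
```

Saving the `state_dict` rather than the whole module keeps the checkpoint independent of the class's import path; it can be restored with `model.load_state_dict(torch.load(ckpt_path))`.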
To evaluate the trained model on the test set:

    python evaluate.py

This will output performance metrics such as MSE, RMSE, and R² score, providing insights into the model's predictive accuracy.
The MainNetwork class in model.py defines the architecture:
- Ligand Branch:
  - Input: Molecular graph (node features, edge indices, edge attributes).
  - Layers: Two TransformerConv layers with multi-head attention, followed by global mean and max pooling.
  - Output: A fixed-size vector representation of the ligand.
- Protein Branch:
  - Input: Pre-computed ESM embeddings (320 dimensions).
  - Processing: Directly passed to the concatenation stage (extensible for further processing).
- Interaction Module:
  - The ligand and protein vectors are concatenated.
  - Passed through a sequence of Linear layers with BatchNorm, ReLU activation, and Dropout for regularization.
  - Final Output: Predicted binding affinity (scalar value).
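The pooling and interaction stages described above might look roughly like this in plain PyTorch. This is a sketch under stated assumptions, not the code in model.py: `pool_mean_max`, `InteractionHead`, the hidden width of 256, and the dropout rate of 0.2 are hypothetical, the TransformerConv layers are omitted, and random node features stand in for their output. Only the 320-dimensional protein embedding size comes from the description.

```python
import torch
import torch.nn as nn


def pool_mean_max(node_feats: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Concatenate the per-graph mean and max of node features.

    node_feats: [num_nodes, hidden]; batch: [num_nodes] graph index per node.
    """
    num_graphs = int(batch.max().item()) + 1
    pooled = []
    for g in range(num_graphs):
        nodes = node_feats[batch == g]
        pooled.append(torch.cat([nodes.mean(dim=0), nodes.max(dim=0).values]))
    return torch.stack(pooled)  # [num_graphs, 2 * hidden]


class InteractionHead(nn.Module):
    """Hypothetical stand-in for the interaction module of MainNetwork."""

    def __init__(self, ligand_dim: int, protein_dim: int = 320,
                 hidden: int = 256, dropout: float = 0.2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(ligand_dim + protein_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),  # scalar binding affinity
        )

    def forward(self, ligand_vec: torch.Tensor, protein_emb: torch.Tensor) -> torch.Tensor:
        joint = torch.cat([ligand_vec, protein_emb], dim=1)
        return self.mlp(joint).squeeze(-1)


# Two molecules with 3 and 4 atoms; 64-d node features stand in for GNN output.
node_feats = torch.randn(7, 64)
batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])
ligand_vec = pool_mean_max(node_feats, batch)  # shape [2, 128]
protein_emb = torch.randn(2, 320)              # placeholder for ESM embeddings

head = InteractionHead(ligand_dim=ligand_vec.shape[1]).eval()
with torch.no_grad():
    pred = head(ligand_vec, protein_emb)       # shape [2]
```

Concatenating mean and max pooling doubles the ligand vector width, which is why the head's input size is derived from `ligand_vec.shape[1]` rather than hard-coded.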
Training progress and loss curves are automatically saved in the results/ folder. Key metrics tracked include:
- MSE (Mean Squared Error): The average squared difference between predicted and measured binding affinities.
- R² Score: The proportion of variance in the measured binding affinities that is explained by the model's predictions.
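For reference, the metrics can be computed directly from their definitions. The functions below are a plain-Python sketch with made-up values; the project's evaluate.py may compute them differently (for example with a library implementation).

```python
import math


def mse(y_true, y_pred):
    """Mean of the squared prediction errors."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true, y_pred):
    """Square root of MSE, in the same units as the target."""
    return math.sqrt(mse(y_true, y_pred))


def r2_score(y_true, y_pred):
    """1 - (residual sum of squares / total sum of squares)."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Made-up values for illustration only.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.0, 2.0, 3.0, 6.0]

print(mse(y_true, y_pred))       # 1.0
print(rmse(y_true, y_pred))      # 1.0
print(r2_score(y_true, y_pred))  # approximately 0.2
```

An R² of 1 means perfect predictions, 0 means the model does no better than always predicting the mean, and negative values mean it does worse than that.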
This project is licensed under the MIT License. See the LICENSE file for details.