'''
Author: DengRui
Date: 2024-1-20 13:15:08
LastEditors: DengRui
LastEditTime: 2024-1-20 13:23:06
FilePath: /DeepSub/embedding/esm_embedding_esm2.py
Description: Embed protein sequences with ESM-2.
Copyright (c) 2024 by DengRui, All Rights Reserved.
'''
import pandas as pd
import numpy as np
import esm
import torch
import os
from tqdm import tqdm
from tool import config as cfg

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
# Load data
def load_data(path=cfg.DATA_PATH):
    """
    Load the dataset from the specified path and return it both as a list of
    (label, seq) tuples ready for the ESM batch converter and as the raw frame.

    Args:
        path (str): Path to the dataset CSV, defaulting to cfg.DATA_PATH.

    Returns:
        Tuple[List[Tuple[int, str]], pd.DataFrame]: A list of (row index, seq)
        tuples, plus the DataFrame with its columns renamed to uniprot_id and seq.
    """
    dataset = pd.read_csv(path)
    # dataset.Sequence = dataset.Sequence.apply(lambda x: x[:10000])
    # dataset = dataset.sample(1000)
    dataset = dataset.rename(columns={'Entry': 'uniprot_id', 'Sequence': 'seq'})
    df_data = list(zip(dataset.uniprot_id.index, dataset.seq))
    return df_data, dataset
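# A minimal sketch of the expected input. The column names 'Entry' and
# 'Sequence' come from the rename above; the two rows are invented placeholders:
#
#   Entry,Sequence
#   P12345,MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
#   Q67890,MADEEKLPPGWEKRMSRSSGRVYYFNHITNASQ
#
# load_data then yields df_data like [(0, 'MKTAY...'), (1, 'MADEE...')],
# i.e. the integer row index stands in as the label for the batch converter.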
# Set model
def set_model():
    """
    Load the pre-trained ESM-2 model and move it to the selected device.

    Args:
        None

    Returns:
        Tuple[torch.nn.Module, Callable, esm.data.Alphabet]: The model in eval
        mode, its batch converter, and its alphabet.
    """
    model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model.eval()
    model = model.to(device)
    return model, batch_converter, alphabet
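# A hedged swap for quick local tests (esm2_t12_35M_UR50D is a real fair-esm
# checkpoint, but using it here is an assumption, not what this script ships with):
#
#   model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
#
# The smaller model embeds in 480 dimensions and has 12 layers, so the
# hard-coded repr_layers=[33] in get_rep_seq would need to become [12].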
# Sequences embedding
def get_rep_seq(sequences, model, batch_converter, alphabet):
    """
    Embed sequences using the given model.

    Args:
        sequences (list): List of (label, seq) tuples to be embedded.
        model (torch.nn.Module): The pre-trained ESM model.
        batch_converter (Callable): The batch converter for the given model.
        alphabet (esm.data.Alphabet): The alphabet for the given model.

    Returns:
        pd.DataFrame: One row per sequence, with feature columns f0..f{d-1}.
    """
    batch_labels, batch_strs, batch_tokens = batch_converter(sequences)
    batch_tokens = batch_tokens.to(device)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_representations = results["representations"][33]
    sequence_representations = []
    for i, tokens_len in enumerate(batch_lens):
        # Average over the sequence length (dropping the BOS/EOS tokens)
        # to obtain a single vector per sequence
        sequence_representations.append(token_representations[i, 1:tokens_len - 1].mean(0))
    np_list = []
    # Detach the tensors and move them to the CPU as numpy arrays
    for ten in sequence_representations:
        np_list.append(ten.cpu().detach().numpy())
    res = pd.DataFrame(np_list)
    res.columns = ['f' + str(i) for i in range(res.shape[1])]
    torch.cuda.empty_cache()
    return res
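# A minimal usage sketch under assumed toy data (the labels and protein
# strings below are made up; any (label, sequence) pairs work):
#
#   model, batch_converter, alphabet = set_model()
#   toy_batch = [(0, 'MKTAYIAKQR'), (1, 'MLEHVAKQRSE')]
#   reps = get_rep_seq(toy_batch, model, batch_converter, alphabet)
#   print(reps.shape)  # (2, 1280) for esm2_t33_650M_UR50D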
def save_feature(folder_path_feature, res):
    """
    Save the embedding results to a Feather file.

    Args:
        folder_path_feature (str): Folder (with trailing separator, since the
            path is built by string concatenation) where the file is written.
        res (pd.DataFrame): The embedding results for the given sequences.

    Returns:
        None
    """
    res.to_feather(f'{folder_path_feature}feature_esm2.feather')
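# The saved table round-trips with pandas; a sketch, assuming the same
# cfg.FEATURE_PATH used below:
#
#   feats = pd.read_feather(f'{cfg.FEATURE_PATH}feature_esm2.feather')
#   print(feats.filter(like='f').shape)  # embeddings without the id column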
def main():
    """
    Embed the full dataset in mini-batches and write the features to disk.

    Sequences are processed `stride` at a time, and a checkpoint Feather file
    is written every 500 sequences so long runs can be inspected mid-way.

    Args:
        None

    Returns:
        None
    """
    # Load data
    df_data, dataset = load_data()
    # Set model
    model, batch_converter, alphabet = set_model()
    # Make sure the feature directory exists
    folder_path_feature = cfg.FEATURE_PATH
    if not os.path.exists(folder_path_feature):
        os.makedirs(folder_path_feature)
    # Embedding: ceil(len(df_data) / stride) mini-batches
    stride = 2
    num_iterations = len(df_data) // stride
    if len(df_data) % stride != 0:
        num_iterations += 1
    all_results = pd.DataFrame()
    for i in tqdm(range(num_iterations)):
        start = i * stride
        end = start + stride
        current_data = df_data[start:end]
        rep33 = get_rep_seq(current_data, model, batch_converter, alphabet)
        rep33['uniprot_id'] = dataset[start:end].uniprot_id.tolist()
        # Move uniprot_id to the front
        cols = list(rep33.columns)
        cols = [cols[-1]] + cols[:-1]
        rep33 = rep33[cols]
        all_results = pd.concat([all_results, rep33], ignore_index=True)
        if end % 500 == 0:
            all_results.to_feather(f'{folder_path_feature}feature_esm2_checkpoint.feather')
    # Save the final feature table
    save_feature(folder_path_feature, all_results)


if __name__ == '__main__':
    main()