-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathpreprocess.py
More file actions
150 lines (112 loc) · 4.84 KB
/
preprocess.py
File metadata and controls
150 lines (112 loc) · 4.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
Data preprocessing script for protein sequences.
Converts FASTA files to the required format for training.
"""
import argparse
import pickle
from pathlib import Path
from typing import List, Tuple, Dict
import numpy as np
import torch
import esm
def parse_args():
    """Build and parse the command-line arguments for preprocessing.

    Returns:
        argparse.Namespace with ``input``, ``output`` and ``split`` attributes.
    """
    cli = argparse.ArgumentParser(description="Preprocess protein data")
    cli.add_argument("--input", type=str, required=True, help="Input FASTA file")
    cli.add_argument("--output", type=str, required=True, help="Output directory")
    cli.add_argument(
        "--split",
        type=str,
        choices=['train', 'valid', 'test'],
        required=True,
        help="Data split type",
    )
    return cli.parse_args()
def parse_fasta_file(fasta_path: str) -> List[Dict]:
    """Parse a custom 3-line FASTA-like format (not standard FASTA).

    Each record occupies three consecutive non-empty lines: a ``>name``
    header, the amino-acid sequence, and a per-residue label string.
    Blank lines anywhere in the file are ignored.

    Args:
        fasta_path: Path to the input file.

    Returns:
        A list of dicts with keys ``name`` (header without ``>``),
        ``sequence`` and ``labels``.

    Raises:
        ValueError: If the non-empty line count is not a multiple of 3,
            or a record does not start with a ``>`` header. Explicit
            raises are used instead of ``assert`` so validation is not
            stripped under ``python -O``.
    """
    with open(fasta_path, 'r') as f:
        lines = [line.strip() for line in f if line.strip()]
    if len(lines) % 3 != 0:
        raise ValueError("File must contain groups of 3 lines: >name, seq, labels")
    proteins = []
    for i in range(0, len(lines), 3):
        if not lines[i].startswith('>'):
            # i indexes the stripped, non-empty lines, not raw file lines.
            raise ValueError(f"Record at non-empty line {i + 1} should start with '>'")
        proteins.append({
            'name': lines[i][1:],
            'sequence': lines[i + 1],
            'labels': lines[i + 2],
        })
    return proteins
def create_protein_list(proteins: List[Dict]) -> List[Tuple]:
    """Create protein list for dataset.

    Expands every protein into one tuple per residue of the form
    ``(running_count, protein_index, residue_index, dataset_tag,
    protein_name, sequence_length)``.
    """
    records: List[Tuple] = []
    for prot_idx, prot in enumerate(proteins):
        prot_name = prot['name']
        n_res = len(prot['sequence'])
        for pos in range(n_res):
            records.append(
                (len(records), prot_idx, pos, 'processed', prot_name, n_res)
            )
    return records
def _encode_proteins(proteins: List[Dict]) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Embed each protein with ESM2 and build per-residue label arrays.

    Loads esm2_t36_3B_UR50D (layer-36 representations, 2560-dim per
    residue) and encodes proteins one at a time. Returns two lists
    aligned with ``proteins``: float32 (seq_len, 2560) embedding arrays
    and int32 (seq_len,) binary label arrays.
    """
    print("Loading ESM2 model (esm2_t36_3B_UR50D)...")
    model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    batch_converter = alphabet.get_batch_converter()

    embeddings_all: List[np.ndarray] = []
    labels_all: List[np.ndarray] = []
    with torch.no_grad():
        for idx, protein in enumerate(proteins):
            seq = protein['sequence']
            seq_len = len(seq)
            if seq_len == 0:
                # Keep output lists aligned with the input even for
                # empty sequences.
                embeddings_all.append(np.empty((0, 2560), dtype=np.float32))
                labels_all.append(np.empty((0,), dtype=np.int32))
                continue
            _, _, batch_tokens = batch_converter([(protein['name'], seq)])
            batch_tokens = batch_tokens.to(device)
            results = model(batch_tokens, repr_layers=[36], return_contacts=False)
            # Slice off the BOS token and keep exactly seq_len residue positions.
            per_residue = results["representations"][36][0, 1:seq_len + 1, :].detach()
            embeddings_all.append(per_residue.to(dtype=torch.float32).cpu().numpy())
            label_seq = protein['labels']
            labels_all.append(np.asarray(
                [1 if label_seq[pos] == '1' else 0 for pos in range(seq_len)],
                dtype=np.int32,
            ))
            if (idx + 1) % 10 == 0 or idx == 0:
                print(f"Encoded {idx + 1}/{len(proteins)} proteins")
    return embeddings_all, labels_all


def _save_split(split_dir: Path, input_stem: str,
                embeddings: List[np.ndarray],
                labels: List[np.ndarray],
                protein_list: List[Tuple]) -> None:
    """Pickle embeddings, labels and the residue index list into split_dir."""
    encode_file = split_dir / f"{input_stem}-ESM2.pkl"
    label_file = split_dir / f"{input_stem}-label.pkl"
    list_file = split_dir / f"{input_stem}-list.pkl"
    with open(encode_file, 'wb') as f:
        pickle.dump(embeddings, f)
    with open(label_file, 'wb') as f:
        pickle.dump(labels, f)
    with open(list_file, 'wb') as f:
        pickle.dump(protein_list, f)
    print(f"Saved processed data to {split_dir}")
    print(f"- Encodings: {encode_file}")
    print(f"- Labels: {label_file}")
    print(f"- Protein list: {list_file}")


def main():
    """Entry point: parse the input file, embed with ESM2, and save pickles.

    Reads the 3-line-per-record input file, builds the per-residue index
    list, encodes every sequence with ESM2, and writes three pickle
    files under ``<output>/<split>/``.
    """
    args = parse_args()
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Parse input file
    proteins = parse_fasta_file(args.input)
    print(f"Parsed {len(proteins)} proteins")

    # Create protein list
    protein_list = create_protein_list(proteins)

    # Generate embeddings and labels
    embeddings_all_proteins, labels_all_proteins = _encode_proteins(proteins)

    # Save processed data scoped under the requested split directory
    split_dir = output_dir / args.split
    split_dir.mkdir(parents=True, exist_ok=True)
    input_stem = Path(args.input).stem
    _save_split(split_dir, input_stem,
                embeddings_all_proteins, labels_all_proteins, protein_list)
# Run the preprocessing pipeline only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()