-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathannotate.py
More file actions
106 lines (80 loc) · 3.56 KB
/
annotate.py
File metadata and controls
106 lines (80 loc) · 3.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import numpy as np
from pickle import load, dump
from transformers import EsmModel, EsmTokenizer, RobertaForTokenClassification
from model_utils import get_gene_label, get_embeddings, Seq4Transformer
from Bio.Seq import Seq
from Bio import SeqIO
import os
# Product classes scored by the first logits of the classifier's first output
# position; order must match the model's label order (taken from the original code).
PRODUCT_CLASSES = ['Alkaloid', 'NRP', 'Other', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene']

# Accepted input extensions (compared case-insensitively).
GENBANK_EXTENSIONS = ('.gb', '.gbk')
FASTA_EXTENSIONS = ('.fasta', '.fa')


def _read_protein_sequences(input_file_path):
    """Read protein sequences and their identifiers from a GenBank or FASTA file.

    For GenBank input, every CDS feature with a ``translation`` qualifier is used.
    Its name is the first of ``locus_tag``, ``protein_id`` or ``gene`` that is
    present, falling back to ``<record id>_CDS_<feature index>``. CDS features
    without a translation (e.g. pseudogenes) are skipped with a message instead
    of raising ``KeyError``.

    Args:
        input_file_path (str): Path to a GenBank (.gb/.gbk) or FASTA (.fasta/.fa) file.

    Returns:
        tuple[list[str], list[str]]: ``(seqs, names)`` in file order.

    Raises:
        ValueError: If the file extension is not a supported format.
    """
    seqs = []
    names = []
    lowered_path = input_file_path.lower()
    if lowered_path.endswith(GENBANK_EXTENSIONS):
        for record in SeqIO.parse(input_file_path, 'genbank'):
            for feature_index, feature in enumerate(record.features):
                if feature.type != 'CDS':
                    continue
                qualifiers = feature.qualifiers
                name = next(
                    (qualifiers[key][0] for key in ('locus_tag', 'protein_id', 'gene')
                     if qualifiers.get(key)),
                    f'{record.id}_CDS_{feature_index}',
                )
                translation = qualifiers.get('translation')
                if not translation:
                    print(f'Skipping CDS {name}: no translation qualifier')
                    continue
                seqs.append(translation[0])
                names.append(name)
    elif lowered_path.endswith(FASTA_EXTENSIONS):
        for record in SeqIO.parse(input_file_path, 'fasta'):
            seqs.append(str(record.seq))
            names.append(record.id)
    else:
        raise ValueError("Input file must be in GenBank (.gb/.gbk) or FASTA (.fasta/.fa) format.")
    return seqs, names


def _predict_product(logits):
    """Return the product class whose logit is highest among the first len(PRODUCT_CLASSES).

    Softmax is monotonic, so the argmax of the raw logits equals the argmax of
    their softmax. Skipping the exponentiation also avoids overflow to NaN for
    large logits.
    """
    scores = np.asarray(logits)[:len(PRODUCT_CLASSES)]
    return PRODUCT_CLASSES[int(np.argmax(scores))]


def annotate(cfg, args) -> dict:
    """
    Annotate a GenBank or FASTA file with the CoreFinder model.

    Protein sequences are embedded with ESM-2 and classified per position by a
    RoBERTa token-classification model. The predictions are pickled to
    ``<output>/<input basename>_annotations.pkl``.

    Args:
        cfg (dict): Configuration with keys ``'input'`` (path to the input file),
            ``'output'`` (output folder, created if missing) and ``'model'``
            (folder containing the ``corefinder_model`` checkpoint).
        args (argparse.Namespace): Command line arguments. Kept for interface
            compatibility; the input path is read from ``cfg['input']`` so that
            format detection and parsing use the same file.

    Returns:
        dict: ``{'gene_function': [...], 'product': str, 'gene_name': [...]}``.

    Raises:
        ValueError: If the input format is unsupported or no sequences are found.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    input_file_path = cfg['input']
    output_folder_path = cfg['output']
    os.makedirs(output_folder_path, exist_ok=True)

    # Read and validate the input before loading any (large) model weights.
    seqs, names = _read_protein_sequences(input_file_path)
    if not seqs:
        raise ValueError(f'No protein sequences found in {input_file_path}.')

    # Embed the sequences with ESM-2; inference only, so no autograd graph is kept.
    embedder_path = 'facebook/esm2_t33_650M_UR50D'
    tokenizer = EsmTokenizer.from_pretrained(embedder_path)
    embedder = EsmModel.from_pretrained(embedder_path, output_hidden_states=True).to(device)
    dataset = Seq4Transformer(seqs, tokenizer)
    with torch.no_grad():
        reps, token_type_ids = get_embeddings(embedder, dataset, device)
    # Release the embedder before loading the classifier to lower peak memory.
    del embedder
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # Classify every position of the embedded "sentence" of proteins.
    corefinder_path = os.path.join(cfg['model'], 'corefinder_model')
    model = RobertaForTokenClassification.from_pretrained(corefinder_path).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(
            inputs_embeds=reps.unsqueeze(0).to(device),
            token_type_ids=token_type_ids.unsqueeze(0).to(device),
        ).logits
    # Drop only the batch axis so a single-position output stays 2-D.
    outputs = logits.squeeze(0).cpu().numpy()

    # Row 0 carries the product-class prediction; rows 1.. are per-gene labels.
    # NOTE(review): assumes rows 1.. map one-to-one onto `names` — TODO confirm
    # against Seq4Transformer / get_embeddings (e.g. padding positions).
    preds = {
        'gene_function': [get_gene_label(row) for row in outputs[1:]],
        'product': _predict_product(outputs[0]),
        'gene_name': names,
    }
    if len(preds['gene_function']) != len(names):
        print(f"Warning: {len(preds['gene_function'])} gene predictions "
              f"for {len(names)} input sequences")

    # Save the predictions to a pickle file.
    output_file = os.path.join(
        output_folder_path, f'{os.path.basename(input_file_path)}_annotations.pkl'
    )
    with open(output_file, 'wb') as f:
        dump(preds, f)
    print(f'Annotations saved to {output_file}')
    print(preds)
    return preds