Skip to content

Commit a27a87f

Browse files
authored
Merge pull request #333 from bigict/data
fix: Dataset for inference
2 parents 543fd36 + fe0ec1c commit a27a87f

2 files changed

Lines changed: 163 additions & 221 deletions

File tree

profold2/command/predictor.py

Lines changed: 10 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -26,26 +26,6 @@
2626
from profold2.command.worker import main, autocast_ctx, WorkerModel, WorkerXPU
2727

2828

29-
def _a3m_add_pseudo_linker(sequences, descriptions, pseudo_linker_len=100):
30-
s, d = [], []
31-
32-
delta, domains = 1, []
33-
for seq, desc in zip(sequences, descriptions):
34-
residue_index = torch.arange(len(seq), dtype=torch.int)
35-
residue_index = parse_seq_index(desc, seq, residue_index)
36-
residue_index = residue_index - residue_index[0] + delta
37-
domains += [str_seq_index(residue_index)]
38-
39-
s += [seq]
40-
d += [desc.split()[0]]
41-
delta = residue_index[-1] + pseudo_linker_len + 1
42-
43-
if domains:
44-
domains = ','.join(domains)
45-
d += [f':{domains}']
46-
return [''.join(s)], [' '.join(d)]
47-
48-
4929
def _read_fasta(args): # pylint: disable=redefined-outer-name
5030
def filename_get(fasta_file):
5131
fasta_file = os.path.basename(fasta_file)
@@ -82,33 +62,23 @@ def _create_dataloader(xpu, args): # pylint: disable=redefined-outer-name
8262

8363
sequences, descriptions, msa = [], [], []
8464
for fasta_name, fasta_str in _read_fasta(args):
85-
if args.fasta_fmt == 'a4m':
86-
s = fasta_str.splitlines()
87-
d = [None] * len(s)
88-
else:
89-
s, d = parse_fasta(fasta_str)
90-
if len(s) > 1 and args.fasta_fmt == 'single' and args.add_pseudo_linker:
91-
# Add a pseudo linker between each chain.
92-
s, d = _a3m_add_pseudo_linker(s, d, args.pseudo_linker_len)
93-
assert len(s) == 1
65+
s, d = parse_fasta(fasta_str)
9466
d[0] = f'{fasta_name} {d[0]}' if exists(d[0]) else fasta_name
9567
if args.fasta_fmt == 'single':
96-
sequences += s
97-
descriptions += d
98-
msa += [None] * len(s)
68+
sequences += [s]
69+
descriptions += [d]
70+
msa += [[None] * len(s)]
9971
else:
100-
sequences += s[:1]
101-
descriptions += d[:1]
72+
sequences += [s[:1]]
73+
descriptions += [d[:1]]
10274
if len(s) > args.max_msa_size:
10375
s = s[:1] + list(
10476
np.random.choice(
10577
s, size=args.max_msa_size - 1, replace=False
10678
) if args.max_msa_size > 1 else []
10779
)
108-
msa += [s]
109-
data = ProteinSequenceDataset(
110-
sequences, descriptions, msa=msa, domain_as_seq=args.add_pseudo_linker
111-
)
80+
msa += [[s]]
81+
data = ProteinSequenceDataset(sequences, descriptions, msa=msa)
11282
if xpu.is_available() and WorkerXPU.world_size(args.nnodes) > 1:
11383
kwargs['sampler'] = DistributedSampler(
11484
data,
@@ -181,7 +151,7 @@ def predict_structure(idx, batch):
181151
print_fn=logging.info,
182152
callback_fn=functools.partial(timing_callback, timings, 'predict_structure')
183153
):
184-
logging.debug('Sequence %d shape %s: %s', idx, fasta_name, batch['seq'].shape)
154+
logging.debug('Sequence [%d] %s shape : %s', idx, fasta_name, batch['seq'].shape)
185155
if args.fasta_fmt in ('a3m', 'a4m'):
186156
logging.debug('msa shape %s: %s', fasta_name, batch['msa'].shape)
187157

@@ -340,23 +310,14 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
340310
'--fasta_fmt',
341311
type=str,
342312
default='single',
343-
choices=['single', 'a3m', 'a4m'],
313+
choices=['single', 'a3m'],
344314
help='format of fasta files.'
345315
)
346316

347317
parser.add_argument(
348318
'--data_dir', type=str, default=None, help='load data from dataset.'
349319
)
350320
parser.add_argument('--data_idx', type=str, default=None, help='dataset idx.')
351-
parser.add_argument(
352-
'--add_pseudo_linker', action='store_true', help='enable loading complex data.'
353-
)
354-
parser.add_argument(
355-
'--pseudo_linker_len',
356-
type=int,
357-
default=100,
358-
help='add a pseudolinker with length=PSEUDO_LINKER_LEN.'
359-
)
360321

361322
parser.add_argument(
362323
'--models',

0 commit comments

Comments
 (0)