|
26 | 26 | from profold2.command.worker import main, autocast_ctx, WorkerModel, WorkerXPU |
27 | 27 |
|
28 | 28 |
|
29 | | -def _a3m_add_pseudo_linker(sequences, descriptions, pseudo_linker_len=100): |
30 | | - s, d = [], [] |
31 | | - |
32 | | - delta, domains = 1, [] |
33 | | - for seq, desc in zip(sequences, descriptions): |
34 | | - residue_index = torch.arange(len(seq), dtype=torch.int) |
35 | | - residue_index = parse_seq_index(desc, seq, residue_index) |
36 | | - residue_index = residue_index - residue_index[0] + delta |
37 | | - domains += [str_seq_index(residue_index)] |
38 | | - |
39 | | - s += [seq] |
40 | | - d += [desc.split()[0]] |
41 | | - delta = residue_index[-1] + pseudo_linker_len + 1 |
42 | | - |
43 | | - if domains: |
44 | | - domains = ','.join(domains) |
45 | | - d += [f':{domains}'] |
46 | | - return [''.join(s)], [' '.join(d)] |
47 | | - |
48 | | - |
49 | 29 | def _read_fasta(args): # pylint: disable=redefined-outer-name |
50 | 30 | def filename_get(fasta_file): |
51 | 31 | fasta_file = os.path.basename(fasta_file) |
@@ -82,33 +62,23 @@ def _create_dataloader(xpu, args): # pylint: disable=redefined-outer-name |
82 | 62 |
|
83 | 63 | sequences, descriptions, msa = [], [], [] |
84 | 64 | for fasta_name, fasta_str in _read_fasta(args): |
85 | | - if args.fasta_fmt == 'a4m': |
86 | | - s = fasta_str.splitlines() |
87 | | - d = [None] * len(s) |
88 | | - else: |
89 | | - s, d = parse_fasta(fasta_str) |
90 | | - if len(s) > 1 and args.fasta_fmt == 'single' and args.add_pseudo_linker: |
91 | | - # Add a pseudo linker between each chain. |
92 | | - s, d = _a3m_add_pseudo_linker(s, d, args.pseudo_linker_len) |
93 | | - assert len(s) == 1 |
| 65 | + s, d = parse_fasta(fasta_str) |
94 | 66 | d[0] = f'{fasta_name} {d[0]}' if exists(d[0]) else fasta_name |
95 | 67 | if args.fasta_fmt == 'single': |
96 | | - sequences += s |
97 | | - descriptions += d |
98 | | - msa += [None] * len(s) |
| 68 | + sequences += [s] |
| 69 | + descriptions += [d] |
| 70 | + msa += [[None] * len(s)] |
99 | 71 | else: |
100 | | - sequences += s[:1] |
101 | | - descriptions += d[:1] |
| 72 | + sequences += [s[:1]] |
| 73 | + descriptions += [d[:1]] |
102 | 74 | if len(s) > args.max_msa_size: |
103 | 75 | s = s[:1] + list( |
104 | 76 | np.random.choice( |
105 | 77 | s, size=args.max_msa_size - 1, replace=False |
106 | 78 | ) if args.max_msa_size > 1 else [] |
107 | 79 | ) |
108 | | - msa += [s] |
109 | | - data = ProteinSequenceDataset( |
110 | | - sequences, descriptions, msa=msa, domain_as_seq=args.add_pseudo_linker |
111 | | - ) |
| 80 | + msa += [[s]] |
| 81 | + data = ProteinSequenceDataset(sequences, descriptions, msa=msa) |
112 | 82 | if xpu.is_available() and WorkerXPU.world_size(args.nnodes) > 1: |
113 | 83 | kwargs['sampler'] = DistributedSampler( |
114 | 84 | data, |
@@ -181,7 +151,7 @@ def predict_structure(idx, batch): |
181 | 151 | print_fn=logging.info, |
182 | 152 | callback_fn=functools.partial(timing_callback, timings, 'predict_structure') |
183 | 153 | ): |
184 | | - logging.debug('Sequence %d shape %s: %s', idx, fasta_name, batch['seq'].shape) |
| 154 | + logging.debug('Sequence [%d] %s shape : %s', idx, fasta_name, batch['seq'].shape) |
185 | 155 | if args.fasta_fmt in ('a3m', 'a4m'): |
186 | 156 | logging.debug('msa shape %s: %s', fasta_name, batch['msa'].shape) |
187 | 157 |
|
@@ -340,23 +310,14 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name |
340 | 310 | '--fasta_fmt', |
341 | 311 | type=str, |
342 | 312 | default='single', |
343 | | - choices=['single', 'a3m', 'a4m'], |
| 313 | + choices=['single', 'a3m'], |
344 | 314 | help='format of fasta files.' |
345 | 315 | ) |
346 | 316 |
|
347 | 317 | parser.add_argument( |
348 | 318 | '--data_dir', type=str, default=None, help='load data from dataset.' |
349 | 319 | ) |
350 | 320 | parser.add_argument('--data_idx', type=str, default=None, help='dataset idx.') |
351 | | - parser.add_argument( |
352 | | - '--add_pseudo_linker', action='store_true', help='enable loading complex data.' |
353 | | - ) |
354 | | - parser.add_argument( |
355 | | - '--pseudo_linker_len', |
356 | | - type=int, |
357 | | - default=100, |
358 | | - help='add a pseudolinker with length=PSEUDO_LINKER_LEN.' |
359 | | - ) |
360 | 321 |
|
361 | 322 | parser.add_argument( |
362 | 323 | '--models', |
|
0 commit comments