|
11 | 11 | from argparse import ArgumentParser |
12 | 12 | from pathlib import Path |
13 | 13 | from typing import List, Union |
14 | | -import os |
| 14 | +import os, pandas |
15 | 15 |
|
16 | | -from colabfold.batch import get_queries, msa_to_str, get_queries_pairwise |
| 16 | +from colabfold.inputs import get_queries, msa_to_str, parse_fasta |
| 17 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
17 | 18 |
|
18 | 19 | logger = logging.getLogger(__name__) |
19 | 20 |
|
def get_queries_pairwise(
    input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, List[str], Optional[List[str]]]], bool]:
    """Build pairwise queries that pair the first sequence with each later one.

    Supported inputs are a single ``.csv``/``.tsv`` file with ``id`` and
    ``sequence`` columns, or a single ``.fasta``/``.faa``/``.fa`` file.

    * CSV/TSV: the first row is the anchor; every following row yields one
      query named ``"<first id>&<row id>"`` holding both sequences.
    * FASTA: the first record is the anchor; every following record without
      a ``:`` yields one query named ``"<first header>&<header>"``. A record
      whose sequence contains ``:`` is kept as its own query (also when it
      is the first record), split on ``:`` into its chains.

    All sequences are upper-cased.

    Args:
        input_path: Path to the input file.
        sort_queries_by: Accepted for signature compatibility; this function
            does not use it and keeps the queries in file order.

    Returns:
        ``(queries, is_complex)`` where each query is
        ``(job_name, sequences, None)`` and ``is_complex`` is always ``True``.

    Raises:
        OSError: ``input_path`` does not exist.
        NotImplementedError: ``input_path`` is a directory or an ``.a3m`` file.
        ValueError: The file suffix is not one of the supported formats.
        AssertionError: A CSV/TSV file lacks the ``id`` or ``sequence`` column.
    """
    input_path = Path(input_path)
    if not input_path.exists():
        raise OSError(f"{input_path} could not be found")
    if not input_path.is_file():
        # Directories of inputs are not supported in pairwise mode.
        raise NotImplementedError()

    suffix = input_path.suffix
    queries: List[Tuple[str, List[str], Optional[List[str]]]] = []

    if suffix in (".csv", ".tsv"):
        sep = "\t" if suffix == ".tsv" else ","
        df = pandas.read_csv(input_path, sep=sep)
        assert "id" in df.columns and "sequence" in df.columns, (
            f"{input_path} must contain 'id' and 'sequence' columns"
        )
        rows = list(df[["id", "sequence"]].itertuples(index=False))
        # Only touch the anchor row when there is at least one partner, so a
        # header-only or single-row table simply yields no queries.
        if len(rows) > 1:
            anchor_id, anchor_seq = rows[0]
            anchor_name = str(anchor_id)
            anchor_seq = anchor_seq.upper()
            queries = [
                (anchor_name + "&" + str(seq_id), [anchor_seq, sequence.upper()], None)
                for seq_id, sequence in rows[1:]
            ]
    elif suffix == ".a3m":
        raise NotImplementedError()
    elif suffix in (".fasta", ".faa", ".fa"):
        sequences, headers = parse_fasta(input_path.read_text())
        # Upper-case the anchor once so it matches the upper-cased partners
        # (the CSV/TSV branch upper-cases both sides as well).
        anchor_seq = sequences[0].upper() if sequences else ""
        for i, (sequence, header) in enumerate(zip(sequences, headers)):
            sequence = sequence.upper()
            if ":" in sequence:
                # Complex mode: the record already lists its own chains.
                queries.append((header, sequence.split(":"), None))
            elif i != 0:
                # Single sequence: pair it with the first record.
                queries.append(
                    (headers[0] + "&" + header, [anchor_seq, sequence], None)
                )
    else:
        raise ValueError(f"Unknown file format {input_path.suffix}")

    # Every query produced here is multi-chain, so it is always a complex.
    is_complex = True
    return queries, is_complex
20 | 63 |
|
21 | 64 | def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]): |
22 | 65 | params_log = " ".join(str(i) for i in params) |
@@ -61,8 +104,9 @@ def mmseqs_search_monomer( |
61 | 104 | used_dbs.append(template_db) |
62 | 105 | if use_env: |
63 | 106 | used_dbs.append(metagenomic_db) |
64 | | - |
65 | 107 | for db in used_dbs: |
| 108 | + if str(db) == '.': |
| 109 | + continue |
66 | 110 | if not dbbase.joinpath(f"{db}.dbtype").is_file(): |
67 | 111 | raise FileNotFoundError(f"Database {db} does not exist") |
68 | 112 | if ( |
|
0 commit comments