Skip to content

Commit 7ad01b8

Browse files
author
dohyun-s
committed
fix search.py
1 parent b975e5f commit 7ad01b8

1 file changed

Lines changed: 47 additions & 3 deletions

File tree

colabfold/mmseqs/search.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,55 @@
1111
from argparse import ArgumentParser
1212
from pathlib import Path
1313
from typing import List, Union
14-
import os
14+
import os, pandas
1515

16-
from colabfold.batch import get_queries, msa_to_str, get_queries_pairwise
16+
from colabfold.inputs import get_queries, msa_to_str, parse_fasta
17+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1718

1819
logger = logging.getLogger(__name__)
1920

21+
def get_queries_pairwise(
    input_path: Union[str, Path], sort_queries_by: str = "length"
) -> Tuple[List[Tuple[str, List[str], Optional[List[str]]]], bool]:
    """Read a csv/tsv or fasta file and pair the first sequence with every other one.

    Each returned query is a tuple ``(job_name, sequences, a3m_lines)``:

    * csv/tsv input (columns ``id`` and ``sequence``): one query per row after
      the first, named ``"<first id>&<row id>"`` and holding
      ``[first sequence, row sequence]``.
    * fasta input: every entry without ``":"`` (except the first entry) is
      paired with the first entry in the same way; an entry containing ``":"``
      is treated as an already-specified complex and split on ``":"``.

    All sequences are upper-cased and ``a3m_lines`` is always ``None``.

    Args:
        input_path: Path to a ``.csv``, ``.tsv``, ``.fasta``, ``.faa`` or
            ``.fa`` file.
        sort_queries_by: Not used by this function; kept so existing callers
            that pass it keep working.

    Returns:
        ``(queries, is_complex)`` where ``is_complex`` is always ``True``.

    Raises:
        OSError: If ``input_path`` does not exist.
        ValueError: If the file suffix is not recognised, or a csv/tsv file
            lacks the ``id`` or ``sequence`` column.
        NotImplementedError: If ``input_path`` is a directory or an ``.a3m``
            file.
    """
    input_path = Path(input_path)
    if not input_path.exists():
        raise OSError(f"{input_path} could not be found")
    if not input_path.is_file():
        # Directories of fasta files are not supported for pairwise queries.
        raise NotImplementedError()

    suffix = input_path.suffix
    if suffix in (".csv", ".tsv"):
        sep = "\t" if suffix == ".tsv" else ","
        df = pandas.read_csv(input_path, sep=sep)
        # Raise explicitly instead of asserting: asserts vanish under ``-O``.
        missing = [col for col in ("id", "sequence") if col not in df.columns]
        if missing:
            raise ValueError(
                f"{input_path} is missing required column(s): {', '.join(missing)}"
            )
        # Take rows positionally so the result does not depend on the
        # DataFrame's index labels.
        rows = list(df[["id", "sequence"]].itertuples(index=False, name=None))
        queries = []
        if len(rows) > 1:
            first_id, first_sequence = rows[0]
            first_id = str(first_id)
            first_sequence = first_sequence.upper()
            queries = [
                (first_id + "&" + str(seq_id), [first_sequence, sequence.upper()], None)
                for seq_id, sequence in rows[1:]
            ]
    elif suffix == ".a3m":
        raise NotImplementedError()
    elif suffix in (".fasta", ".faa", ".fa"):
        (sequences, headers) = parse_fasta(input_path.read_text())
        queries = []
        for i, (sequence, header) in enumerate(zip(sequences, headers)):
            sequence = sequence.upper()
            if ":" in sequence:
                # Complex mode: the entry already lists its own chains.
                queries.append((header, sequence.split(":"), None))
            elif i != 0:
                # Single sequence: pair it with the first entry. The first
                # entry is upper-cased too so both partners are normalised.
                queries.append(
                    (headers[0] + "&" + header, [sequences[0].upper(), sequence], None)
                )
    else:
        raise ValueError(f"Unknown file format {input_path.suffix}")

    is_complex = True
    return queries, is_complex
2063

2164
def run_mmseqs(mmseqs: Path, params: List[Union[str, Path]]):
2265
params_log = " ".join(str(i) for i in params)
@@ -61,8 +104,9 @@ def mmseqs_search_monomer(
61104
used_dbs.append(template_db)
62105
if use_env:
63106
used_dbs.append(metagenomic_db)
64-
65107
for db in used_dbs:
108+
if str(db) == '.':
109+
continue
66110
if not dbbase.joinpath(f"{db}.dbtype").is_file():
67111
raise FileNotFoundError(f"Database {db} does not exist")
68112
if (

0 commit comments

Comments
 (0)