diff --git a/colabfold/mmseqs/search.py b/colabfold/mmseqs/search.py index 6455f4ae..c0dffcc7 100644 --- a/colabfold/mmseqs/search.py +++ b/colabfold/mmseqs/search.py @@ -390,7 +390,10 @@ def main(): help="Database preload mode 0: auto, 1: fread, 2: mmap, 3: mmap+touch", ) parser.add_argument( - "--unpack", type=int, default=1, choices=[0, 1], help="Unpack results to loose files or keep MMseqs2 databases." + "--unpack", type=int, default=1, choices=[0, 1], help="Unpack results to a3m text files or keep MMseqs2 databases." + ) + parser.add_argument( + "--merge-a3m", type=int, default=1, choices=[0, 1], help="Merge unpacked a3m files into a single a3m file." ) parser.add_argument( "--threads", type=int, default=64, help="Number of threads to use." @@ -522,7 +525,7 @@ def main(): unpack=args.unpack, ) - if args.unpack or args.af3_json: + if args.merge_a3m or args.af3_json: id = 0 for job_number, ( raw_jobname, @@ -537,22 +540,19 @@ def main(): for seq in query_sequences: with args.base.joinpath(f"{id}.a3m").open("r") as f: unpaired_msa.append(f.read()) - if args.af3_json: - args.base.joinpath(f"{id}.a3m").unlink() - - if args.use_env_pairing: - with open(args.base.joinpath(f"{id}.paired.a3m"), 'a') as file_pair: - with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env: - while chunk := file_pair_env.read(10 * 1024 * 1024): - file_pair.write(chunk) - if args.unpack: - args.base.joinpath(f"{id}.env.paired.a3m").unlink() + args.base.joinpath(f"{id}.a3m").unlink() if len(query_seqs_cardinality) > 1: with args.base.joinpath(f"{id}.paired.a3m").open("r") as f: - paired_msa.append(f.read()) - if args.unpack: - args.base.joinpath(f"{id}.paired.a3m").unlink() + paired = f.read() + if args.merge_a3m: + args.base.joinpath(f"{id}.paired.a3m").unlink() + if args.use_env_pairing: + with open(args.base.joinpath(f"{id}.env.paired.a3m"), 'r') as file_pair_env: + paired += file_pair_env.read() + if args.merge_a3m: + args.base.joinpath(f"{id}.env.paired.a3m").unlink() + paired_msa.append(paired) id += 1 if args.af3_json: @@ -560,7 +560,7 @@ def main(): with open(args.base.joinpath(f"{job_number}.json"), 'w') as f: f.write(json.dumps(af3.content, indent=4)) - if args.unpack: + if args.merge_a3m: msa = msa_to_str( unpaired_msa, paired_msa, query_sequences, query_seqs_cardinality ) @@ -583,10 +583,34 @@ def main(): if args.unpack: # rename a3m files for job_number, (raw_jobname, query_sequences, query_seqs_cardinality, other_molecules) in enumerate(queries_unique): - os.rename( - args.base.joinpath(f"{job_number}.a3m"), - args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"), - ) + if args.merge_a3m: + os.rename( + args.base.joinpath(f"{job_number}.a3m"), + args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"), + ) + else: + # Rename unpaired, paired and env paired files. + nseqs = len(query_seqs_cardinality) + for id in range(job_number * nseqs, (job_number+1) * nseqs): + if nseqs > 1: + os.rename( + args.base.joinpath(f"{id}.a3m"), + args.base.joinpath(f"{safe_filename(raw_jobname)}_{id}.a3m"), + ) + os.rename( + args.base.joinpath(f"{id}.paired.a3m"), + args.base.joinpath(f"{safe_filename(raw_jobname)}_{id}.paired.a3m"), + ) + if args.use_env_pairing: + os.rename( + args.base.joinpath(f"{id}.env.paired.a3m"), + args.base.joinpath(f"{safe_filename(raw_jobname)}_{id}.env.paired.a3m"), + ) + else: + os.rename( + args.base.joinpath(f"{id}.a3m"), + args.base.joinpath(f"{safe_filename(raw_jobname)}.a3m"), + ) # rename m8 files if args.use_templates: