Add improved alignment-db creation script
- much faster due to the use of threading and mp - also supports sharding
This commit is contained in:
parent
afd919825b
commit
a3bb3c40c7
|
@ -0,0 +1,193 @@
|
|||
"""
|
||||
This is a modified version of the create_alignment_db.py script in OpenFold
|
||||
which supports sharding into multiple files. The created index is already a
|
||||
super index, meaning that "unify_alignment_db_indices.py" does not need to be
|
||||
run on the output index. Additionally this script uses threading and
|
||||
multiprocessing and is much faster than the old version.
|
||||
"""
|
||||
import argparse
import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from math import ceil
from pathlib import Path
from typing import List

from tqdm import tqdm

def split_file_list(file_list, n_shards):
    """
    Split the total file list into n_shards sublists, round-robin.

    Args:
        file_list: Sequence of items to distribute across shards.
        n_shards: Number of sublists to create; must be >= 1.

    Returns:
        A list of n_shards lists; shard i holds file_list[i::n_shards],
        so the concatenation of all shards contains exactly the elements
        of file_list.

    Raises:
        ValueError: If n_shards is less than 1.
    """
    # Explicit validation instead of the old post-hoc `assert`, which is
    # stripped when Python runs with -O.
    if n_shards < 1:
        raise ValueError(f"n_shards must be >= 1, got {n_shards}")

    # Round-robin assignment; extended slicing guarantees every element
    # appears in exactly one shard.
    return [file_list[i::n_shards] for i in range(n_shards)]
def chunked_iterator(lst, chunk_size):
    """Yield successive slices of lst, each at most chunk_size long."""
    offsets = range(0, len(lst), chunk_size)
    yield from (lst[off : off + chunk_size] for off in offsets)
def read_chain_dir(chain_dir) -> dict:
    """
    Read every alignment file in a single chain directory.

    Returns a one-entry dict mapping the normalized chain name
    ("<pdb_id_lower>_<chain>") to a list of (file_name, file_bytes)
    tuples, ordered by file name.

    Raises:
        ValueError: If chain_dir is not a directory.
    """
    if not chain_dir.is_dir():
        raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")

    # Normalize: PDB IDs are stored lowercase; the chain ID keeps its case.
    pdb_id, chain = chain_dir.name.split("_")
    chain_name = f"{pdb_id.lower()}_{chain}"

    # Sort for a deterministic file order (stable byte offsets across runs).
    file_data = [
        (path.name, path.read_bytes()) for path in sorted(chain_dir.iterdir())
    ]

    return {chain_name: file_data}
def process_chunk(chain_files: List[Path]) -> dict:
    """
    Collect file names and raw bytes for every chain directory in a chunk.

    Reads are fanned out over a thread pool: this is I/O-bound work, so
    threads overlap the blocking file reads.
    """
    merged: dict = {}

    with ThreadPoolExecutor() as pool:
        for single_chain_data in pool.map(read_chain_dir, chain_files):
            merged.update(single_chain_data)

    return merged
def create_index_default_dict() -> dict:
    """Return a fresh, empty index entry: {"db": None, "files": []}."""
    return dict(db=None, files=[])
def create_shard(
    shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
    """
    Create a single shard of the alignment database and return the
    corresponding index entries for the super index.

    Concatenates the raw bytes of every alignment file for every chain in
    shard_files into "<output_name>_<shard_num>.db" and records, per chain,
    the db file name plus (file_name, db_offset, file_length) tuples.

    Args:
        shard_files: Chain directories assigned to this shard.
        output_dir: Directory the .db file is written into.
        output_name: Base name of the database (shared by all shards).
        shard_num: Index of this shard; used in the file name and as the
            tqdm bar position so per-shard progress bars don't collide.

    Returns:
        The shard's index:
        {chain_name: {"db": str, "files": [(file_name, db_offset, file_length)]}}
    """
    CHUNK_SIZE = 200  # chain directories processed per batch

    # {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
    shard_index = defaultdict(create_index_default_dict)
    chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)

    pbar_desc = f"Shard {shard_num}"
    output_path = output_dir / f"{output_name}_{shard_num}.db"

    db_offset = 0
    # Context manager guarantees the db file is closed even if a read or
    # write fails (the previous open()/close() pair leaked the handle on
    # exceptions).
    with open(output_path, "wb") as db_file:
        for files_chunk in tqdm(
            chunk_iter,
            total=ceil(len(shard_files) / CHUNK_SIZE),
            desc=pbar_desc,
            position=shard_num,
            leave=False,
        ):
            # get processed files for one chunk (threaded, I/O-bound)
            chunk_data = process_chunk(files_chunk)

            # append bytes to the shard db and record offsets in the index
            for chain_name, file_data in chunk_data.items():
                shard_index[chain_name]["db"] = output_path.name

                for file_name, file_bytes in file_data:
                    file_length = len(file_bytes)
                    shard_index[chain_name]["files"].append(
                        (file_name, db_offset, file_length)
                    )
                    db_file.write(file_bytes)
                    db_offset += file_length

    return shard_index
def main(args):
    """
    Create a sharded alignment database plus a single super index.

    Chain directories under args.alignment_dir are split round-robin into
    args.n_shards groups; each group is written (in its own process) to
    "<output_db_name>_<i>.db" inside args.output_db_path, and the per-shard
    indices are merged into "<output_db_name>.index" (no separate
    unification step needed).
    """
    alignment_dir = args.alignment_dir
    output_dir = args.output_db_path
    output_db_name = args.output_db_name
    n_shards = args.n_shards

    # Make sure the output location exists before workers try to write into it.
    output_dir.mkdir(parents=True, exist_ok=True)

    # get all chain dirs in alignment_dir (sorted for deterministic sharding)
    print("Getting chain directories...")
    all_chain_dirs = sorted(tqdm(alignment_dir.iterdir()))

    # split chain dirs into n_shards round-robin sublists
    chain_dir_shards = split_file_list(all_chain_dirs, n_shards)

    # total index for all shards
    super_index = {}

    # create a shard for each sublist, one process per shard
    print(f"Creating {n_shards} alignment-db files...")
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                create_shard, shard_files, output_dir, output_db_name, shard_num
            )
            for shard_num, shard_files in enumerate(chain_dir_shards)
        ]

        # NOTE: the result was previously also bound to `shard_index`,
        # shadowing the enumerate variable above; renamed for clarity.
        for future in as_completed(futures):
            super_index.update(future.result())
    print("\nCreated all shards.")

    # write super index to file
    print("\nWriting super index...")
    index_path = output_dir / f"{output_db_name}.index"
    with open(index_path, "w") as fp:
        json.dump(super_index, fp, indent=4)

    print("Done.")
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
This script creates an alignment database format from a directory of
|
||||
precomputed alignments. For better file system health, the total
|
||||
database is split into n_shards files, where each shard contains a
|
||||
subset of the total alignments. The output is a directory containing the
|
||||
n_shards database files, and a single index file mapping chain names to
|
||||
the database file and byte offsets for each alignment file.
|
||||
|
||||
Note: For optimal performance, your machine should have at least as many
|
||||
cores as shards you want to create.
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"alignment_dir",
|
||||
type=Path,
|
||||
help="""Path to precomputed alignment directory, with one subdirectory
|
||||
per chain.""",
|
||||
)
|
||||
parser.add_argument("output_db_path", type=Path)
|
||||
parser.add_argument("output_db_name", type=str)
|
||||
parser.add_argument(
|
||||
"n_shards", type=int, help="Number of shards to split the database into"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
Loading…
Reference in New Issue