Add improved alignment-db creation script

- much faster due to the use of threading and multiprocessing
- also supports sharding
Lukas Jarosch 2023-10-06 15:14:02 -07:00
parent afd919825b
commit a3bb3c40c7
1 changed file with 193 additions and 0 deletions

@@ -0,0 +1,193 @@
"""
This is a modified version of the create_alignment_db.py script in OpenFold
which supports sharding into multiple files. The created index is already a
super index, meaning that "unify_alignment_db_indices.py" does not need to be
run on the output index. Additionally this script uses threading and
multiprocessing and is much faster than the old version.
"""
import argparse
import json
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from math import ceil
from pathlib import Path
from typing import List

from tqdm import tqdm


def split_file_list(file_list, n_shards):
"""
Split up the total file list into n_shards sublists.
"""
split_list = []
for i in range(n_shards):
split_list.append(file_list[i::n_shards])
assert len([f for sublist in split_list for f in sublist]) == len(file_list)
return split_list
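
# Shards are filled round-robin, e.g. (illustrative):
#   split_file_list([d0, d1, d2, d3, d4], n_shards=2) -> [[d0, d2, d4], [d1, d3]]

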
def chunked_iterator(lst, chunk_size):
"""Iterate over a list in chunks of size chunk_size."""
for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size]
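
# e.g. (illustrative): chunked_iterator([1, 2, 3, 4, 5], chunk_size=2)
# yields [1, 2], then [3, 4], then [5]

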
def read_chain_dir(chain_dir) -> dict:
"""
Read all alignment files in a single chain directory and return a dict
mapping chain name to file names and bytes.
"""
if not chain_dir.is_dir():
raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")
# ensure that PDB IDs are all lowercase
pdb_id, chain = chain_dir.name.split("_")
pdb_id = pdb_id.lower()
chain_name = f"{pdb_id}_{chain}"
file_data = []
for file_path in sorted(chain_dir.iterdir()):
file_name = file_path.name
with open(file_path, "rb") as file:
file_bytes = file.read()
file_data.append((file_name, file_bytes))
return {chain_name: file_data}
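
# Illustrative return value for a chain directory "1ABC_A" holding two
# alignment files (file names are made up):
#   {"1abc_A": [("bfd_uniclust_hits.a3m", b"..."), ("pdb70_hits.hhr", b"...")]}

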
def process_chunk(chain_files: List[Path]) -> dict:
"""
Returns the file names and bytes for all chains in a chunk of files.
"""
chunk_data = {}
with ThreadPoolExecutor() as executor:
for file_data in executor.map(read_chain_dir, chain_files):
chunk_data.update(file_data)
return chunk_data


def create_index_default_dict() -> dict:
    """
    Returns a default entry for the index: no shard file assigned yet and an
    empty file list. Defined at module level rather than as a lambda so the
    defaultdict stays picklable across process boundaries.
    """
    return {"db": None, "files": []}


def create_shard(
shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
) -> dict:
"""
Creates a single shard of the alignment database, and returns the
corresponding indices for the super index.
"""
CHUNK_SIZE = 200
shard_index = defaultdict(
create_index_default_dict
) # {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)
pbar_desc = f"Shard {shard_num}"
output_path = output_dir / f"{output_name}_{shard_num}.db"
    db_offset = 0

    with open(output_path, "wb") as db_file:
        for files_chunk in tqdm(
            chunk_iter,
            total=ceil(len(shard_files) / CHUNK_SIZE),
            desc=pbar_desc,
            position=shard_num,
            leave=False,
        ):
            # get processed files for one chunk
            chunk_data = process_chunk(files_chunk)

            # write to db and store info in index
            for chain_name, file_data in chunk_data.items():
                shard_index[chain_name]["db"] = output_path.name

                for file_name, file_bytes in file_data:
                    file_length = len(file_bytes)
                    shard_index[chain_name]["files"].append(
                        (file_name, db_offset, file_length)
                    )
                    db_file.write(file_bytes)
                    db_offset += file_length
return shard_index
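
# Sketch of how a consumer can read one chain's files back via an index entry
# produced above (names are illustrative, not part of this script):
#   entry = super_index["1abc_A"]
#   with open(output_dir / entry["db"], "rb") as db:
#       for file_name, db_offset, file_length in entry["files"]:
#           db.seek(db_offset)
#           file_bytes = db.read(file_length)

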
def main(args):
alignment_dir = args.alignment_dir
output_dir = args.output_db_path
output_db_name = args.output_db_name
n_shards = args.n_shards
# get all chain dirs in alignment_dir
print("Getting chain directories...")
    all_chain_dirs = sorted(tqdm(alignment_dir.iterdir()))
# split chain dirs into n_shards sublists
chain_dir_shards = split_file_list(all_chain_dirs, n_shards)
# total index for all shards
super_index = {}
# create a shard for each sublist
print(f"Creating {n_shards} alignment-db files...")
with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                create_shard, shard_files, output_dir, output_db_name, shard_num
            )
            for shard_num, shard_files in enumerate(chain_dir_shards)
        ]
for future in as_completed(futures):
shard_index = future.result()
super_index.update(shard_index)
print("\nCreated all shards.")
# write super index to file
print("\nWriting super index...")
index_path = output_dir / f"{output_db_name}.index"
with open(index_path, "w") as fp:
json.dump(super_index, fp, indent=4)
print("Done.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""
This script creates an alignment database format from a directory of
precomputed alignments. For better file system health, the total
database is split into n_shards files, where each shard contains a
subset of the total alignments. The output is a directory containing the
n_shards database files, and a single index file mapping chain names to
the database file and byte offsets for each alignment file.
Note: For optimal performance, your machine should have at least as many
cores as shards you want to create.
"""
)
parser.add_argument(
"alignment_dir",
type=Path,
help="""Path to precomputed alignment directory, with one subdirectory
per chain.""",
)
parser.add_argument("output_db_path", type=Path)
parser.add_argument("output_db_name", type=str)
parser.add_argument(
"n_shards", type=int, help="Number of shards to split the database into"
)
args = parser.parse_args()
main(args)
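
# Example invocation (hypothetical paths; the script's file name is assumed):
#   python create_alignment_db_sharded.py /data/alignments /data/alignment_dbs \
#       alignment_db 10
# With these arguments the script writes alignment_db_0.db ... alignment_db_9.db
# and alignment_db.index into /data/alignment_dbs.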