# openfold/scripts/generate_chain_data_cache.py

import argparse
from functools import partial
import json
import logging
from multiprocessing import Pool
import os
import string
import sys
sys.path.append(".") # an innocent hack to get this to run from the top level
from collections import defaultdict
from tqdm import tqdm
from openfold.data.mmcif_parsing import parse
from openfold.np import protein, residue_constants


def parse_file(
    f,
    args,
    chain_cluster_size_dict,
):
    file_id, ext = os.path.splitext(f)
    if(ext == ".cif"):
        with open(os.path.join(args.data_dir, f), "r") as fp:
            mmcif_string = fp.read()

        mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
        if mmcif.mmcif_object is None:
            logging.info(f"Could not parse {f}. Skipping...")
            return {}
        else:
            mmcif = mmcif.mmcif_object

        out = {}
        # One cache entry per chain, keyed "{file_id}_{chain_id}"
        for chain_id, seq in mmcif.chain_to_seqres.items():
            full_name = "_".join([file_id, chain_id])
            out[full_name] = {}
            local_data = out[full_name]
            local_data["release_date"] = mmcif.header["release_date"]
            local_data["seq"] = seq
            local_data["resolution"] = mmcif.header["resolution"]

            if(chain_cluster_size_dict is not None):
                cluster_size = chain_cluster_size_dict.get(
                    full_name.upper(), -1
                )
                local_data["cluster_size"] = cluster_size
    elif(ext == ".pdb"):
        with open(os.path.join(args.data_dir, f), "r") as fp:
            pdb_string = fp.read()

        protein_object = protein.from_pdb_string(pdb_string, None)
        aatype = protein_object.aatype
        chain_index = protein_object.chain_index

        # Group residue types by chain so each chain's sequence can be rebuilt
        chain_dict = defaultdict(list)
        for i in range(aatype.shape[0]):
            chain_dict[chain_index[i]].append(
                residue_constants.restypes_with_x[aatype[i]]
            )

        out = {}
        # Integer chain indices are mapped back to letters (A, B, ...)
        chain_tags = string.ascii_uppercase
        for chain, seq in chain_dict.items():
            full_name = "_".join([file_id, chain_tags[chain]])
            out[full_name] = {}
            local_data = out[full_name]
            # Bare PDB inputs carry no resolution metadata, so store 0.
            local_data["resolution"] = 0.
            local_data["seq"] = ''.join(seq)

            if(chain_cluster_size_dict is not None):
                cluster_size = chain_cluster_size_dict.get(
                    full_name.upper(), -1
                )
                local_data["cluster_size"] = cluster_size

    return out
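
# Shape of a parse_file() result, for orientation (IDs and values below are
# hypothetical, not taken from a real structure): an mmCIF "1abc.cif" with
# chains A and B would yield
#
#   {
#       "1abc_A": {
#           "release_date": "2000-01-01",
#           "seq": "MKT...",
#           "resolution": 2.0,
#           "cluster_size": 17,  # -1 if the chain is missing from the cluster
#       },                       # file; key absent if no cluster file is given
#       "1abc_B": {...},
#   }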


def main(args):
    # Optionally map each (upper-cased) "{file_id}_{chain_id}" to the size of
    # the cluster that contains it
    chain_cluster_size_dict = None
    if(args.cluster_file is not None):
        chain_cluster_size_dict = {}
        with open(args.cluster_file, "r") as fp:
            clusters = [l.strip() for l in fp.readlines()]

        for cluster in clusters:
            chain_ids = cluster.split()
            cluster_len = len(chain_ids)
            for chain_id in chain_ids:
                chain_id = chain_id.upper()
                chain_cluster_size_dict[chain_id] = cluster_len

    accepted_exts = [".cif", ".pdb"]
    files = list(os.listdir(args.data_dir))
    files = [f for f in files if os.path.splitext(f)[-1] in accepted_exts]

    fn = partial(
        parse_file,
        args=args,
        chain_cluster_size_dict=chain_cluster_size_dict,
    )

    # Parse files in parallel, merging the per-file dicts into one flat cache
    data = {}
    with Pool(processes=args.no_workers) as p:
        with tqdm(total=len(files)) as pbar:
            for d in p.imap_unordered(fn, files, chunksize=args.chunksize):
                data.update(d)
                pbar.update()

    with open(args.output_path, "w") as fp:
        fp.write(json.dumps(data, indent=4))
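
# A minimal sketch of one way downstream code might consume the cache, e.g. to
# select chains for training. The filename and thresholds are illustrative
# only, not OpenFold's actual training filters:
#
#   with open("chain_data_cache.json", "r") as fp:
#       cache = json.load(fp)
#
#   train_chains = [
#       name for name, d in cache.items()
#       if d["resolution"] <= 9.0 and len(d["seq"]) >= 16
#   ]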


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "data_dir", type=str, help="Directory containing mmCIF or PDB files"
    )
    parser.add_argument(
        "output_path", type=str, help="Path for .json output"
    )
    parser.add_argument(
        "--cluster_file", type=str, default=None,
        help=(
            "Path to a cluster file (e.g. PDB40), one cluster "
            "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
            "Chains not in this cluster file will NOT be filtered by cluster "
            "size."
        )
    )
    parser.add_argument(
        "--no_workers", type=int, default=4,
        help="Number of workers to use for parsing"
    )
    parser.add_argument(
        "--chunksize", type=int, default=10,
        help="How many files should be distributed to each worker at a time"
    )

    args = parser.parse_args()

    main(args)
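
# Example invocation (all paths are hypothetical):
#
#   python scripts/generate_chain_data_cache.py \
#       data/pdb_mmcif/mmcif_files/ \
#       chain_data_cache.json \
#       --cluster_file clusters-by-entity-40.txt \
#       --no_workers 16
#
# where the cluster file lists one whitespace-separated cluster per line, e.g.
#
#   1ABC_A 2XYZ_B 3DEF_A
#   4GHI_C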