173 lines
5.9 KiB
Python
173 lines
5.9 KiB
Python
import argparse
|
|
import os
|
|
import logging
|
|
import random
|
|
|
|
import numpy
|
|
import torch
|
|
from openfold.config import model_config
|
|
from openfold.data import feature_pipeline
|
|
from openfold.data.data_pipeline import make_sequence_features_with_custom_template
|
|
from openfold.np import protein
|
|
from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
|
|
relax_protein
|
|
from openfold.utils.tensor_utils import (
|
|
tensor_tree_map,
|
|
)
|
|
from scripts.utils import add_data_args
|
|
|
|
# Script-wide logger, named after this file's path. main() reports its
# progress (output paths, relaxation) through it at INFO level.
logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

# Give the root logger a default handler so the records above are emitted.
logging.basicConfig()
|
|
|
|
# torch.set_float32_matmul_precision is only called on torch >= 1.12.
# Selecting "high" precision gives a large speedup on Ampere-class GPUs.
torch_versions = torch.__version__.split(".")
torch_major_version, torch_minor_version = map(int, torch_versions[:2])
if (torch_major_version, torch_minor_version) >= (1, 12):
    torch.set_float32_matmul_precision("high")

# Autograd is switched off globally for the rest of the script.
torch.set_grad_enabled(False)
|
|
|
|
|
|
def main(args):
    """Thread one query sequence onto a template chain and predict its structure.

    Reads the single sequence in ``args.input_fasta`` and builds features from it
    plus chain ``args.chain_id`` of ``args.input_mmcif``. It then runs every
    model yielded by ``load_models_from_command_line``. For each model, it writes
    ``<tag>_<preset>_unrelaxed.pdb`` into that model's output directory and then
    passes the prediction to ``relax_protein``.

    Args:
        args: the namespace built by this script's argument parser (including
            the options added by ``add_data_args``, e.g. ``kalign_binary_path``).

    Raises:
        ValueError: if the FASTA file does not contain exactly one sequence, or
            if ``args.data_random_seed`` cannot be converted to an int.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    config = model_config(args.config_preset)

    # The seed may arrive from the CLI as a string. numpy and torch both need
    # an int, and torch is seeded with seed + 1 below, so coerce explicitly.
    if args.data_random_seed is None:
        random_seed = random.randrange(2**32)
    else:
        random_seed = int(args.data_random_seed)

    numpy.random.seed(random_seed)
    torch.manual_seed(random_seed + 1)

    feature_processor = feature_pipeline.FeaturePipeline(config.data)

    with open(args.input_fasta) as fasta_file:
        tags, sequences = parse_fasta(fasta_file.read())

    if len(sequences) != 1:
        raise ValueError("the threading script can only process a single sequence")

    query_sequence = sequences[0]
    query_tag = tags[0]

    feature_dict = make_sequence_features_with_custom_template(
        query_sequence,
        args.input_mmcif,
        args.template_id,
        args.chain_id,
        args.kalign_binary_path,
    )
    processed_feature_dict = feature_processor.process_features(
        feature_dict, mode='predict',
    )
    # Move every processed feature onto the requested device once; the same
    # tensors are reused as input for each model below.
    processed_feature_dict = {
        k: torch.as_tensor(v, device=args.model_device)
        for k, v in processed_feature_dict.items()
    }

    model_generator = load_models_from_command_line(
        config,
        args.model_device,
        args.openfold_checkpoint_path,
        args.jax_param_path,
        args.output_dir,
    )
    output_name = f'{query_tag}_{args.config_preset}'

    for model, output_directory in model_generator:
        out = run_model(model, processed_feature_dict, query_tag, args.output_dir)

        # Toss out the recycling dimensions --- we don't need them anymore.
        # Bind the result to a new name. Rebinding processed_feature_dict here
        # would feed numpy arrays without the recycling dimension to the next
        # model when the generator yields more than one.
        output_feature_dict = tensor_tree_map(
            lambda x: numpy.array(x[..., -1].cpu()),
            processed_feature_dict,
        )
        out = tensor_tree_map(lambda x: numpy.array(x.cpu()), out)

        unrelaxed_protein = prep_output(
            out,
            output_feature_dict,
            feature_dict,
            feature_processor,
            args.config_preset,
            200,  # ri_multimer_gap; there are no multimer sequences here, so its value does not matter
            args.subtract_plddt,
        )

        unrelaxed_output_path = os.path.join(
            output_directory, f'{output_name}_unrelaxed.pdb'
        )

        with open(unrelaxed_output_path, 'w') as fp:
            fp.write(protein.to_pdb(unrelaxed_protein))

        logger.info(f"Output written to {unrelaxed_output_path}...")

        logger.info(f"Running relaxation on {unrelaxed_output_path}...")
        relax_protein(config, args.model_device, unrelaxed_protein, output_directory, output_name, False)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments, fill in a default JAX
    # parameter path if no weights were given, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("input_fasta", type=str, help="the path to a fasta file containing a single sequence to thread")
    parser.add_argument("input_mmcif", type=str, help="the path to an mmcif file to thread the sequence on to")

    parser.add_argument("--template_id", type=str, help="a PDB id or other identifier for the template")

    parser.add_argument(
        "--chain_id", type=str,
        help="""The chain ID of the chain in the template to use"""
    )

    parser.add_argument(
        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
             device name is accepted (e.g. "cpu", "cuda:0")"""
    )
    parser.add_argument(
        "--config_preset", type=str, default="model_1",
        help="""Name of a model config preset defined in openfold/config.py"""
    )
    parser.add_argument(
        "--jax_param_path", type=str, default=None,
        help="""Path to JAX model parameters. If None, and openfold_checkpoint_path
             is also None, parameters are selected automatically according to
             the model name from openfold/resources/params"""
    )
    parser.add_argument(
        "--openfold_checkpoint_path", type=str, default=None,
        help="""Path to OpenFold checkpoint. Can be either a DeepSpeed
             checkpoint directory or a .pt file"""
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
    )
    # The original help text opened with four quotes, leaking a stray '"'.
    parser.add_argument(
        "--subtract_plddt", action="store_true", default=False,
        help="""Whether to output (100 - pLDDT) in the B-factor column instead
             of the pLDDT itself"""
    )

    # Parse as int: main() seeds numpy with this value and torch with value + 1,
    # which fails with a TypeError on a string.
    parser.add_argument(
        "--data_random_seed", type=int, default=None,
        help="""Seed for the numpy and torch RNGs. A random seed is drawn
             when omitted"""
    )

    add_data_args(parser)

    args = parser.parse_args()

    # With no weights given, use params_<preset>.npz under
    # openfold/resources/params (a path relative to the working directory).
    if args.jax_param_path is None and args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold", "resources", "params",
            "params_" + args.config_preset + ".npz"
        )

    if args.model_device == "cpu" and torch.cuda.is_available():
        logger.warning(
            """The model is being run on CPU. Consider specifying
            --model_device for better performance"""
        )

    main(args)