# openfold/train_openfold.py


import argparse
import logging
import os
import sys
import json
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.plugins.environments import MPIEnvironment
from pytorch_lightning import seed_everything
import torch
import wandb
from deepspeed.utils import zero_to_fp32
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants
from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import (
    EarlyStoppingVerbose,
)
from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.validation_metrics import (
    drmsd,
    gdt_ts,
    gdt_ha,
)
from openfold.utils.import_weights import (
    import_jax_weights_,
    import_openfold_weights_
)
from openfold.utils.logger import PerformanceLoggingCallback


class OpenFoldWrapper(pl.LightningModule):
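    """A thin PyTorch Lightning wrapper around the AlphaFold model.

    Holds the model, the AlphaFold loss, and an exponential moving average
    (EMA) of the model weights, and implements the training/validation steps,
    metric logging, optimizer configuration, and checkpoint hooks.
    """
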
    def __init__(self, config):
        super(OpenFoldWrapper, self).__init__()
        self.config = config
        self.model = AlphaFold(config)
        self.is_multimer = self.config.globals.is_multimer
        self.loss = AlphaFoldLoss(config.loss)
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        self.cached_weights = None
        self.last_lr_step = -1
        self.save_hyperparameters()

    def forward(self, batch):
        return self.model(batch)
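
    # Logs every entry of the loss breakdown and the structural validation
    # metrics under a "train/" or "val/" prefix.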
    def _log(self, loss_breakdown, batch, outputs, train=True):
        phase = "train" if train else "val"
        for loss_name, indiv_loss in loss_breakdown.items():
            self.log(
                f"{phase}/{loss_name}",
                indiv_loss,
                prog_bar=(loss_name == 'loss'),
                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
            )

            if train:
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
                    on_step=False, on_epoch=True, logger=True, sync_dist=False,
                )

        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
                prog_bar=(k == 'loss'),
                on_step=False, on_epoch=True, logger=True, sync_dist=False,
            )
    def training_step(self, batch, batch_idx):
        if self.ema.device != batch["aatype"].device:
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)

        # Remove the recycling dimension
        batch = tensor_tree_map(lambda t: t[..., -1], batch)

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss
        loss, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        # Log it
        self._log(loss_breakdown, batch, outputs)

        return loss

    def on_before_zero_grad(self, *args, **kwargs):
        self.ema.update(self.model)
    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if self.cached_weights is None:
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
            clone_param = lambda t: t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)

        # Run the model
        outputs = self(batch)
        batch = tensor_tree_map(lambda t: t[..., -1], batch)
        batch["use_clamped_fape"] = 0.

        if self.is_multimer:
            batch = multi_chain_permutation_align(out=outputs,
                                                  features=batch,
                                                  ground_truth=ground_truth)

        # Compute loss and other metrics
        _, loss_breakdown = self.loss(
            outputs, batch, _return_breakdown=True
        )

        self._log(loss_breakdown, batch, outputs, train=False)

    def on_validation_epoch_end(self):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None
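
    # Computes structural validation metrics: lDDT-Ca and dRMSD-Ca always,
    # plus GDT-TS, GDT-HA, and alignment RMSD when superimposition_metrics
    # is set (i.e. during validation).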
    def _compute_validation_metrics(self,
                                    batch,
                                    outputs,
                                    superimposition_metrics=False
                                    ):
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
        ca_pos = residue_constants.atom_order["CA"]
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
            all_atom_mask,
            eps=self.config.globals.eps,
            per_residue=False,
        )
        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )
        metrics["drmsd_ca"] = drmsd_ca_score

        if superimposition_metrics:
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
            gdt_ts_score = gdt_ts(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            gdt_ha_score = gdt_ha(
                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
            )
            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics
    def configure_optimizers(self,
                             learning_rate: float = 1e-3,
                             eps: float = 1e-5,
                             ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )

        if self.last_lr_step != -1:
            for group in optimizer.param_groups:
                if 'initial_lr' not in group:
                    group['initial_lr'] = learning_rate

        lr_scheduler = AlphaFoldLRScheduler(
            optimizer,
            last_epoch=self.last_lr_step
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "name": "AlphaFoldLRScheduler",
            }
        }
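
    # The EMA state is saved alongside the Lightning checkpoint; template
    # parameters are dropped on load when templates are disabled in the
    # current config.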
    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
        if not self.model.template_config.enabled:
            ema["params"] = {k: v for k, v in ema["params"].items() if "template" not in k}
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
        checkpoint["ema"] = self.ema.state_dict()

    def resume_last_lr_step(self, lr_step):
        self.last_lr_step = lr_step

    def load_from_jax(self, jax_path):
        model_basename = os.path.splitext(
            os.path.basename(
                os.path.normpath(jax_path)
            )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
            self.model, jax_path, version=model_version
        )


def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
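    """Read the model state dict from a DeepSpeed checkpoint directory.

    Resolves the checkpoint tag from the 'latest' file and loads the
    corresponding model state file via DeepSpeed's zero_to_fp32 utilities.
    """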
    latest_path = os.path.join(checkpoint_dir, 'latest')
    if os.path.isfile(latest_path):
        with open(latest_path, 'r') as fd:
            tag = fd.read().strip()
    else:
        raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)

    _DS_CHECKPOINT_VERSION = 2  # based on manual parsing of checkpoint files
    state_file = zero_to_fp32.get_model_state_file(ds_checkpoint_dir, _DS_CHECKPOINT_VERSION)
    return torch.load(state_file)


def main(args):
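    """Build the model, data module, callbacks, loggers, and distributed
    strategy from the parsed arguments, then launch training with
    trainer.fit()."""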
    if args.seed is not None:
        seed_everything(args.seed, workers=True)

    is_low_precision = args.precision in [
        "16", "16-true", "16-mixed", "bf16", "bf16-mixed"]
    config = model_config(
        args.config_preset,
        train=True,
        low_prec=is_low_precision,
    )

    if args.experiment_config_json:
        with open(args.experiment_config_json, 'r') as f:
            custom_config_dict = json.load(f)
        config.update_from_flattened_dict(custom_config_dict)

    model_module = OpenFoldWrapper(config)
    if args.resume_from_ckpt:
        if args.resume_model_weights_only:
            # Load the checkpoint
            if os.path.isdir(args.resume_from_ckpt):
                sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
                    args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)

            # Process the state dict
            if 'module' in sd:
                sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            elif 'state_dict' in sd:
                import_openfold_weights_(
                    model=model_module, state_dict=sd['state_dict'])
            else:
                # Loading from pre-trained model
                sd = {'model.' + k: v for k, v in sd.items()}
                import_openfold_weights_(model=model_module, state_dict=sd)
            logging.info("Successfully loaded model weights...")
        else:  # Loads a checkpoint to start from a specific time step
            if os.path.isdir(args.resume_from_ckpt):
                sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
            else:
                sd = torch.load(args.resume_from_ckpt)
            last_global_step = int(sd['global_step'])
            model_module.resume_last_lr_step(last_global_step)
            logging.info("Successfully loaded last lr step...")

    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if args.script_modules:
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()
    callbacks = []
    if args.checkpoint_every_epoch:
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
            save_top_k=-1,
        )
        callbacks.append(mc)

    if args.early_stopping:
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
            patience=args.patience,
            verbose=False,
            mode="max",
            check_finite=True,
            strict=True,
        )
        callbacks.append(es)

    if args.log_performance:
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
            global_batch_size=global_batch_size,
        )
        callbacks.append(perf)

    if args.log_lr:
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)
    loggers = []
    # Default avoids int(None) when PMI_RANK is unset
    is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK", -1)) == 0)
    if args.wandb:
        if args.mpi_plugin and is_rank_zero:
            wandb_init_dict = dict(
                name=args.experiment_name,
                project=args.wandb_project,
                id=args.wandb_id,
                dir=args.output_dir,
                resume="allow",
                anonymous=None,
                entity=args.wandb_entity
            )
            wandb.run = wandb.init(**wandb_init_dict)
        wdb_logger = WandbLogger(
            name=args.experiment_name,
            save_dir=args.output_dir,
            id=args.wandb_id,
            project=args.wandb_project,
            entity=args.wandb_entity,
        )
        loggers.append(wdb_logger)
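
    # Pick the distributed strategy: DeepSpeed when a DeepSpeed config is
    # provided, DDP for multi-GPU or multi-node runs, and no special strategy
    # otherwise.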
    cluster_environment = MPIEnvironment() if args.mpi_plugin else None
    if args.deepspeed_config_path is not None:
        strategy = DeepSpeedStrategy(
            config=args.deepspeed_config_path,
            cluster_environment=cluster_environment,
        )
        if args.wandb and is_rank_zero:
            wdb_logger.experiment.save(args.deepspeed_config_path)
            wdb_logger.experiment.save("openfold/config.py")
    elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
        strategy = DDPStrategy(find_unused_parameters=False,
                               cluster_environment=cluster_environment)
    else:
        strategy = None

    if args.wandb and is_rank_zero:
        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
        wdb_logger.experiment.save(f"{freeze_path}")
    trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
                   'flush_logs_ever_n_steps', 'num_sanity_val_steps',
                   'reload_dataloaders_every_n_epochs']
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer_args.update({
        'default_root_dir': args.output_dir,
        'strategy': strategy,
        'callbacks': callbacks,
        'logger': loggers,
    })

    trainer = pl.Trainer(**trainer_args)

    if args.resume_model_weights_only:
        ckpt_path = None
    else:
        ckpt_path = args.resume_from_ckpt

    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )


def bool_type(bool_str: str):
    bool_str_lower = bool_str.lower()
    if bool_str_lower in ('false', 'f', 'no', 'n', '0'):
        return False
    elif bool_str_lower in ('true', 't', 'yes', 'y', '1'):
        return True
    else:
        raise ValueError(f'Cannot interpret {bool_str} as bool')


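# Illustrative example invocation (all paths, dates, and values below are
# placeholders; adjust them to your own data layout):
#
#   python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
#       output_dir/ 2021-10-10 \
#       --config_preset initial_training \
#       --template_release_dates_cache_path mmcif_cache.json \
#       --precision bf16 --gpus 4 --num_nodes 1 \
#       --seed 42 --checkpoint_every_epoch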
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "train_data_dir", type=str,
        help="Directory containing training mmCIF files"
    )
    parser.add_argument(
        "train_alignment_dir", type=str,
        help="Directory containing precomputed training alignments"
    )
    parser.add_argument(
        "template_mmcif_dir", type=str,
        help="Directory containing mmCIF files to search for templates"
    )
    parser.add_argument(
        "output_dir", type=str,
        help='''Directory in which to output checkpoints, logs, etc. Ignored
             if not on rank 0'''
    )
    parser.add_argument(
        "max_template_date", type=str,
        help='''Cutoff for all templates. In training mode, templates are also
             filtered by the release date of the target'''
    )
    parser.add_argument(
        "--train_mmcif_data_cache_path", type=str, default=None,
        help="Path to the JSON file that records metadata for the mmCIF structures used during training"
    )
    parser.add_argument(
        "--use_single_seq_mode", type=bool_type, default=False,
        help="Use single-sequence embeddings instead of MSAs."
    )
    parser.add_argument(
        "--distillation_data_dir", type=str, default=None,
        help="Directory containing training PDB files"
    )
    parser.add_argument(
        "--distillation_alignment_dir", type=str, default=None,
        help="Directory containing precomputed distillation alignments"
    )
    parser.add_argument(
        "--val_data_dir", type=str, default=None,
        help="Directory containing validation mmCIF files"
    )
    parser.add_argument(
        "--val_alignment_dir", type=str, default=None,
        help="Directory containing precomputed validation alignments"
    )
    parser.add_argument(
        "--val_mmcif_data_cache_path", type=str, default=None,
        help="Path to the JSON file that records metadata for the mmCIF structures used during validation"
    )
    parser.add_argument(
        "--kalign_binary_path", type=str, default='/usr/bin/kalign',
        help="Path to the kalign binary"
    )
    parser.add_argument(
        "--train_filter_path", type=str, default=None,
        help='''Optional path to a text file containing names of training
             examples to include, one per line. Used to filter the training
             set'''
    )
    parser.add_argument(
        "--distillation_filter_path", type=str, default=None,
        help="""See --train_filter_path"""
    )
    parser.add_argument(
        "--obsolete_pdbs_file_path", type=str, default=None,
        help="""Path to obsolete.dat file containing list of obsolete PDBs and
             their replacements."""
    )
    parser.add_argument(
        "--template_release_dates_cache_path", type=str, default=None,
        help="""Output of scripts/generate_mmcif_cache.py run on template mmCIF
             files."""
    )
    parser.add_argument(
        "--use_small_bfd", type=bool_type, default=False,
        help="Whether to use a reduced version of the BFD database"
    )
    parser.add_argument(
        "--seed", type=int, default=None,
        help="Random seed"
    )
    parser.add_argument(
        "--deepspeed_config_path", type=str, default=None,
        help="Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
    )
    parser.add_argument(
        "--checkpoint_every_epoch", action="store_true", default=False,
        help="""Whether to checkpoint at the end of every training epoch"""
    )
    parser.add_argument(
        "--early_stopping", type=bool_type, default=False,
        help="Whether to stop training when the validation metric (val/lddt_ca) stops improving"
    )
    parser.add_argument(
        "--min_delta", type=float, default=0,
        help="""The smallest change in the monitored validation metric that
             counts as an improvement for the purposes of early stopping"""
    )
    parser.add_argument(
        "--patience", type=int, default=3,
        help="Early stopping patience"
    )
    parser.add_argument(
        "--resume_from_ckpt", type=str, default=None,
        help="Path to a model checkpoint from which to restore training state"
    )
    parser.add_argument(
        "--resume_model_weights_only", type=bool_type, default=False,
        help="Whether to load just model weights as opposed to training state"
    )
    parser.add_argument(
        "--resume_from_jax_params", type=str, default=None,
        help="""Path to an .npz JAX parameter file with which to initialize the model"""
    )
    parser.add_argument(
        "--log_performance", type=bool_type, default=False,
        help="Measure performance"
    )
    parser.add_argument(
        "--wandb", action="store_true", default=False,
        help="Whether to log metrics to Weights & Biases"
    )
    parser.add_argument(
        "--experiment_name", type=str, default=None,
        help="Name of the current experiment. Used for wandb logging"
    )
    parser.add_argument(
        "--wandb_id", type=str, default=None,
        help="ID of a previous run to be resumed"
    )
    parser.add_argument(
        "--wandb_project", type=str, default=None,
        help="Name of the wandb project to which this run will belong"
    )
    parser.add_argument(
        "--wandb_entity", type=str, default=None,
        help="wandb username or team name to which runs are attributed"
    )
    parser.add_argument(
        "--script_modules", type=bool_type, default=False,
        help="Whether to TorchScript-compile eligible components of the model"
    )
    parser.add_argument(
        "--train_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--distillation_chain_data_cache_path", type=str, default=None,
    )
    parser.add_argument(
        "--train_epoch_len", type=int, default=10000,
        help=(
            "The virtual length of each training epoch. Stochastic filtering "
            "of training data means that training datasets have no "
            "well-defined length. This virtual length affects frequency of "
            "validation & checkpointing (by default, one of each per epoch)."
        )
    )
    parser.add_argument(
        "--log_lr", action="store_true", default=False,
        help="Whether to log the actual learning rate"
    )
    parser.add_argument(
        "--config_preset", type=str, default="initial_training",
        help=(
            'Config setting. Choose e.g. "initial_training", "finetuning", '
            '"model_1", etc. By default, the actual values in the config are '
            'used.'
        )
    )
    parser.add_argument(
        "--_distillation_structure_index_path", type=str, default=None,
    )
    parser.add_argument(
        "--alignment_index_path", type=str, default=None,
        help="Training alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--distillation_alignment_index_path", type=str, default=None,
        help="Distillation alignment index. See the README for instructions."
    )
    parser.add_argument(
        "--experiment_config_json", default="",
        help="Path to a JSON file with custom config values that override the selected config preset",
    )
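    # The JSON is applied via config.update_from_flattened_dict, so keys are
    # flattened, dot-separated config paths. A hypothetical example file
    # (key names are illustrative and may differ between presets):
    #   {"globals.chunk_size": 128, "data.train.crop_size": 256}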
    parser.add_argument(
        "--gpus", type=int, default=1,
        help="Number of GPUs per node. Used to determine the distributed strategy and effective batch size."
    )
    parser.add_argument(
        "--mpi_plugin", action="store_true", default=False,
        help="Whether to use MPI for parallel processing"
    )
    trainer_group = parser.add_argument_group(
        'Arguments to pass to the PyTorch Lightning Trainer')
    trainer_group.add_argument(
        "--num_nodes", type=int, default=1,
    )
    trainer_group.add_argument(
        "--precision", type=str, default='bf16',
        help='Sets the training precision; lower precision improves runtime performance.',
    )
    trainer_group.add_argument(
        "--max_epochs", type=int, default=1,
    )
    trainer_group.add_argument(
        "--log_every_n_steps", type=int, default=25,
    )
    trainer_group.add_argument(
        "--flush_logs_every_n_steps", type=int, default=5,
    )
    trainer_group.add_argument(
        "--num_sanity_val_steps", type=int, default=0,
    )
    trainer_group.add_argument(
        "--reload_dataloaders_every_n_epochs", type=int, default=1,
    )
    trainer_group.add_argument(
        "--accumulate_grad_batches", type=int, default=1,
        help="Accumulate gradients over k batches before each optimizer step."
    )
    args = parser.parse_args()

    if args.seed is None and (
        (args.gpus is not None and args.gpus > 1) or
        (args.num_nodes is not None and args.num_nodes > 1)
    ):
        raise ValueError("For distributed training, --seed must be specified")

    if str(args.precision) == "16" and args.deepspeed_config_path is not None:
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if args.resume_from_jax_params is not None and args.resume_from_ckpt is not None:
        raise ValueError("Choose between loading pretrained JAX weights and resuming from a checkpoint")

    main(args)