# Training OpenFold
## Background
This guide covers how to train an OpenFold model for monomers. Some additional instructions are provided at the end for fine-tuning your model.
## Prerequisites
This guide requires the following:
- An installation of OpenFold and its dependencies (including the jackhmmer and hhblits dependencies)
- A preprocessed dataset:
  - For this guide, we use the original OpenFold dataset, which is available on RODA and processed with these instructions.
- GPUs configured with CUDA. Training OpenFold with CPUs only is not supported.
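As a quick sanity check before starting a run (a sketch, not part of the official setup), you can confirm that PyTorch can see your GPUs:

```bash
# Quick check that CUDA-enabled GPUs are visible to PyTorch
python3 -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count())"
```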
## Training a new OpenFold model
### Basic command
For a dataset that has the default alignment file structure, e.g.:

```
$DATA_DIR
├── pdb_data
│   ├── mmcifs
│   │   ├── 3lrm.cif
│   │   ├── 6kwc.cif
│   │   └── ...
│   ├── obsolete.dat
│   ├── duplicate_pdb_chains.txt
│   └── data_caches
│       ├── chain_data_cache.json
│       └── mmcif_cache.json
└── alignment_data
    └── alignments
        ├── 3lrm_A/
        ├── 3lrm_B/
        ├── 6kwc_A/
        └── ...
```
The basic command to train a new OpenFold model is:
```bash
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignments $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--max_template_date 2021-10-10 \
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
--config_preset initial_training \
--seed 42 \
--obsolete_pdbs_file_path $DATA_DIR/pdb_data/obsolete.dat \
--num_nodes 1 \
--gpus 4 \
--num_workers 4
```
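The shell variables referenced above are placeholders. A minimal sketch of setting them, assuming the directory layout shown earlier and reusing the training mmCIFs as templates:

```bash
# Illustrative paths; adjust to your own setup
export DATA_DIR=/path/to/openfold_data
export TEMPLATE_MMCIF_DIR=$DATA_DIR/pdb_data/mmcifs   # templates may be the same directory as the training mmCIFs
export OUTPUT_DIR=/path/to/training_output
```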
The required arguments are:

- `mmcif_dir`: mmCIF files for the training set.
- `alignments_dir`: Alignments for the sequences in `mmcif_dir`; see the expected directory structure above.
- `template_mmcif_dir`: Template mmCIF files with structures; this can be the same directory as `mmcif_dir`. The `max_template_date` and `template_release_dates_cache_path` arguments specify which templates are allowed based on a date cutoff.
- `output_dir`: Where model checkpoint files and other outputs will be saved.
Commonly used flags include:

- `config_preset`: Which selection of hyperparameters should be used for initial model training. Commonly used configs are defined in `openfold/config.py`.
- `num_nodes` and `gpus`: The number of nodes and GPUs available to train OpenFold.
- `seed`: The random seed.
- `num_workers`: The number of CPU workers to assign for creating dataset examples.
- `obsolete_pdbs_file_path`: Obsolete PDB IDs that should be excluded from training.
- `val_data_dir` and `val_alignment_dir`: The data directory and alignments for the validation dataset.
Note that `--seed` must be specified to correctly configure training examples on multi-GPU training runs.
## Training with the OpenFold Dataset Configuration
If the OpenFold alignment database setup is used, resulting in a data directory such as:

```
$DATA_DIR
├── duplicate_pdb_chains.txt
├── pdb_data
│   └── mmcifs
│       ├── 3lrm.cif
│       └── 6kwc.cif
└── alignment_data
    └── alignment_db
        ├── alignment_db_0.db
        ├── alignment_db_1.db
        ├── ...
        ├── alignment_db_9.db
        └── alignment_db.index
```
then the training command uses the `--alignment_index_path` argument to specify the `alignment_db.index` file, e.g.:
```bash
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignment_db $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--max_template_date 2021-10-10 \
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
--alignment_index_path $DATA_DIR/alignment_data/alignment_db/alignment_db.index \
--config_preset initial_training \
--seed 42 \
--obsolete_pdbs_file_path $DATA_DIR/pdb_data/obsolete.dat \
--num_nodes 1 \
--gpus 4 \
--num_workers 4
```
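As a quick sanity check on the database setup (a sketch, assuming the index file is a JSON object keyed by chain ID, which is worth verifying against your own data):

```bash
# Count the chains covered by the alignment index
python3 -c "import json; print(len(json.load(open('$DATA_DIR/alignment_data/alignment_db/alignment_db.index'))), 'chains indexed')"
```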
## Additional command line flag options

Here we provide brief descriptions of options for customizing your OpenFold training run. A full description of all flags can be accessed by using the `--help` option of the script.
- Use the DeepSpeed acceleration strategy: `--deepspeed_config`. This option configures OpenFold to use custom DeepSpeed kernels. It requires a `deepspeed_config.json`; you can create your own or use the one in the OpenFold directory.
- Use a validation dataset: Specify validation dataset paths with `--val_data_dir` and `--val_alignment_dir`. Validation metrics will be evaluated on these datasets.
- Use a self-distillation dataset: Specify paths with the `--distillation_data_dir` and `--distillation_alignment_dir` flags.
- Change specific parameters in the model or data setup: `--experiment_config_json`. These parameters must be defined in `openfold/config.py`. For example, to change the crop size for training a model, you can write the following JSON (see the sketch after this list for how to pass it in):

  ```json
  { "data.train.crop_size": 128 }
  ```

- Configure training settings with PyTorch Lightning: Some flags, e.g. `--precision` and `--max_epochs`, configure training behavior. See the PyTorch Lightning Trainer args section in the `--help` menu for more information, and consult the PyTorch Lightning documentation.
  - Precision: On A100s, OpenFold training works best with bfloat16 precision (e.g. `--precision bf16-mixed`).
- Restart training from an existing checkpoint: Use `--resume_from_ckpt` to restart training from an existing checkpoint.
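As referenced in the list above, a minimal sketch of overriding a config value via `--experiment_config_json` (the file name `experiment_config.json` is an illustrative choice; other arguments are elided for brevity):

```bash
# Write an override file; the key must be defined in openfold/config.py
cat > experiment_config.json <<'EOF'
{ "data.train.crop_size": 128 }
EOF

# Pass the file to the training script alongside the usual arguments
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignments $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--config_preset initial_training \
--experiment_config_json experiment_config.json
```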
## Advanced Training Configurations
### Fine-tuning from existing model weights
If you have existing model weights, you can fine-tune the model by specifying a checkpoint path with the `--resume_from_ckpt` and `--resume_model_weights_only` arguments, e.g.:
```bash
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignment_db $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--max_template_date 2021-10-10 \
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
--config_preset finetuning \
--alignment_index_path $DATA_DIR/alignment_data/alignment_db/alignment_db.index \
--seed 4242022 \
--obsolete_pdbs_file_path $DATA_DIR/pdb_data/obsolete.dat \
--num_nodes 1 \
--gpus 4 \
--num_workers 4 \
--resume_from_ckpt $CHECKPOINT_PATH \
--resume_model_weights_only
```
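Note that `--resume_model_weights_only` loads only the weights and starts a fresh fine-tuning run. To instead resume an interrupted run with its full training state, a sketch (the checkpoint path is illustrative; point at whatever checkpoint your run wrote under `$OUTPUT_DIR`, and keep the other arguments from your original command):

```bash
# Resume full training state (weights + optimizer) from a previous checkpoint
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignment_db $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--config_preset finetuning \
--seed 4242022 \
--gpus 4 \
--resume_from_ckpt $OUTPUT_DIR/latest.ckpt
```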
If you have model parameters from OpenFold v1.x, you may need to convert your checkpoint or parameter files. See Converting OpenFold v1 Weights for more details.
### Using MPI

If MPI is configured on your system and you would like to use it to train OpenFold models, you may do so with the following steps:

1. Install the `mpi4py` package, which is available through pip and conda. Please see the mpi4py documentation for further installation instructions.
2. Add the `--mpi_plugin` flag to your training command.
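A minimal sketch of a launch under MPI, assuming a standard `mpirun` launcher with one rank per GPU (exact launcher flags vary across MPI distributions and schedulers):

```bash
# Four ranks on one node, matching --gpus 4 in the commands above
mpirun -np 4 python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignments $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--config_preset initial_training \
--seed 42 \
--num_nodes 1 \
--gpus 4 \
--mpi_plugin
```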
## Training Multimer models
Coming soon.