# Training OpenFold
## Background
This guide covers how to train an OpenFold model for monomers. Some additional instructions are provided at the end for fine-tuning your model.
### Pre-requisites:
This guide requires the following:
- [Installation of OpenFold and dependencies](Installation.md) (including the jackhmmer and HHblits dependencies)
- A preprocessed dataset:
  - For this guide, we will use the original OpenFold dataset, which is available on RODA, processed with [these instructions](OpenFold_Training_Setup.md).
- GPUs configured with CUDA. Training OpenFold with CPUs only is not supported.
## Training a new OpenFold model
#### Basic command
For a dataset that has the default alignment file structure, e.g.
```
$DATA_DIR
├── pdb_data
│   ├── mmcifs
│   │   ├── 3lrm.cif
│   │   ├── 6kwc.cif
│   │   └── ...
│   ├── obsolete.dat
│   ├── duplicate_pdb_chains.txt
│   └── data_caches
│       ├── chain_data_cache.json
│       └── mmcif_cache.json
└── alignment_data
    └── alignments
        ├── 3lrm_A/
        ├── 3lrm_B/
        ├── 6kwc_A/
        └── ...
```
The basic command to train a new OpenFold model is:
```
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignments $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--max_template_date 2021-10-10 \
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
--config_preset initial_training \
--seed 42 \
--obsolete_pdbs_file_path $DATA_DIR/pdb_data/obsolete.dat \
--num_nodes 1 \
--gpus 4 \
--num_workers 4
```
The required arguments are:
- `mmcif_dir`: mmCIF files for the training set.
- `alignments_dir`: Alignments for the sequences in `mmcif_dir`; see the expected directory structure above.
- `template_mmcif_dir`: mmCIF files for template structures. This can be the same directory as `mmcif_dir`. The `--max_template_date` and `--template_release_dates_cache_path` flags determine which templates are allowed, based on a release-date cutoff.
- `output_dir`: Where model checkpoints and other outputs are saved.
Commonly used flags include:
- `config_preset`: Selects the set of hyperparameters to use, e.g. `initial_training` for initial model training. Commonly used presets are defined in [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py).
- `num_nodes` and `gpus`: Specify the number of nodes and GPUs available for training.
- `seed`: Specifies the random seed.
- `num_workers`: Number of CPU workers used to create dataset examples.
- `obsolete_pdbs_file_path`: File listing obsolete PDB IDs that should be excluded from training.
- `val_data_dir` and `val_alignment_dir`: Specify the data directory and alignments for the validation dataset.
```{note}
Note that `--seed` must be specified to correctly configure training examples in multi-GPU training runs.
```
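Because the training script expects these paths to line up with the default layout above, a quick pre-flight check can catch a mistyped path before a multi-GPU job is launched. The sketch below assumes exactly the layout shown earlier; `check_openfold_layout` is a hypothetical helper, not part of OpenFold.

```bash
# Hypothetical pre-flight helper (not part of OpenFold).
# Assumes the default layout: $DATA_DIR/pdb_data/... and $DATA_DIR/alignment_data/alignments.
check_openfold_layout() {
  local data_dir="$1"
  local missing=0
  local path
  # Directories passed as positional arguments to train_openfold.py
  for path in \
    "$data_dir/pdb_data/mmcifs" \
    "$data_dir/alignment_data/alignments"; do
    if [ ! -d "$path" ]; then
      echo "missing directory: $path" >&2
      missing=1
    fi
  done
  # Files passed via --obsolete_pdbs_file_path and the two cache flags
  for path in \
    "$data_dir/pdb_data/obsolete.dat" \
    "$data_dir/pdb_data/data_caches/chain_data_cache.json" \
    "$data_dir/pdb_data/data_caches/mmcif_cache.json"; do
    if [ ! -f "$path" ]; then
      echo "missing file: $path" >&2
      missing=1
    fi
  done
  # Exit status 0 if everything was found, 1 otherwise
  return "$missing"
}
```

For example, `check_openfold_layout "$DATA_DIR" && python3 train_openfold.py ...` only starts training if every expected path is present.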
#### Train with OpenFold Dataset Configuration
If the [OpenFold alignment database](OpenFold_Training_Setup.md#2-creating-alignment-dbs-optional) setup is used, resulting in a data directory such as:
```
$DATA_DIR
├── duplicate_pdb_chains.txt
├── pdb_data
│   └── mmcifs
│       ├── 3lrm.cif
│       └── 6kwc.cif
└── alignment_data
    └── alignment_db
        ├── alignment_db_0.db
        ├── alignment_db_1.db
        ├── ...
        ├── alignment_db_9.db
        └── alignment_db.index
```
The training command uses the `--alignment_index_path` argument to point to the `alignment_db.index` file, e.g.:
```
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignment_db $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--max_template_date 2021-10-10 \
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
--alignment_index_path $DATA_DIR/alignment_data/alignment_db/alignment_db.index \
--config_preset initial_training \
--seed 42 \
--obsolete_pdbs_file_path $DATA_DIR/pdb_data/obsolete.dat \
--num_nodes 1 \
--gpus 4 \
--num_workers 4
```
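With the alignment database layout, a missing index file or an incomplete set of shards is an easy mistake to make. The sketch below checks for `alignment_db.index` and counts `alignment_db_<N>.db` shards in the directory; `check_alignment_db` is a hypothetical helper, and it does not inspect the contents of the index.

```bash
# Hypothetical helper (not part of OpenFold): sanity-check an alignment_db directory.
# Prints the number of alignment_db_*.db shards on success; returns 1 on failure.
check_alignment_db() {
  local db_dir="$1"
  local index="$db_dir/alignment_db.index"
  local shards
  if [ ! -f "$index" ]; then
    echo "missing index: $index" >&2
    return 1
  fi
  # Count shard files directly inside db_dir (not recursive)
  shards=$(find "$db_dir" -maxdepth 1 -type f -name 'alignment_db_*.db' | wc -l)
  shards=$((shards))  # normalize: some wc implementations pad with spaces
  if [ "$shards" -eq 0 ]; then
    echo "no alignment_db_*.db shards found in $db_dir" >&2
    return 1
  fi
  echo "$shards"
}
```

For the layout above, `check_alignment_db "$DATA_DIR/alignment_data/alignment_db"` would be expected to print `10`.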
#### Additional command-line flags
Here we provide brief descriptions of flags for customizing your OpenFold training run. A full description of all flags is available through the script's `--help` option.
- **Use the DeepSpeed acceleration strategy:** `--deepspeed_config` configures OpenFold to use custom DeepSpeed kernels. This option requires a `deepspeed_config.json` file; you can write your own or use the one in the OpenFold directory.
- **Use a validation dataset:** Specify validation dataset paths with `--val_data_dir` and `--val_alignment_dir`. Validation metrics are evaluated on this dataset.
- **Use a self-distillation dataset:** Specify paths with the `--distillation_data_dir` and `--distillation_alignment_dir` flags.
- **Change specific parameters in the model or data setup:** `--experiment_config_json`. These parameters must be defined in [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py). For example, to change the crop size used for training, you can write the following JSON to a file such as `cropsize.json`:
  ```json
  {
    "data.train.crop_size": 128
  }
  ```
- **Configure training settings with PyTorch Lightning:** Some flags, e.g. `--precision` and `--max_epochs`, configure training behavior. See the PyTorch Lightning Trainer arguments section of the `--help` output and consult the [PyTorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for more information.
  - Precision: On A100s, OpenFold training works best with bfloat16 precision (e.g. `--precision bf16-mixed`).
- **Restart training from an existing checkpoint:** Use the `--resume_from_ckpt` flag to restart training from an existing checkpoint.
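As a concrete example of the `--experiment_config_json` workflow, the sketch below writes the crop-size override to `cropsize.json`, checks that it parses as JSON, and shows where the flags go. The file name is arbitrary, and the training command is left as a comment because it needs your data paths.

```bash
# Write the experiment config override (the key must exist in openfold/config.py).
cat > cropsize.json <<'EOF'
{
  "data.train.crop_size": 128
}
EOF

# Check that the file parses as JSON; python3 -m json.tool exits non-zero on a parse error.
python3 -m json.tool cropsize.json > /dev/null || echo "cropsize.json is not valid JSON" >&2

# Then add the override (and, on A100s, bf16 mixed precision) to your training command, e.g.:
#   python3 train_openfold.py ... \
#       --experiment_config_json cropsize.json \
#       --precision bf16-mixed
```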
## Advanced Training Configurations
### Fine tuning from existing model weights
If you have existing model weights, you can fine-tune the model by passing a checkpoint path to `--resume_from_ckpt` together with the `--resume_model_weights_only` flag, e.g.:
```
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignment_db $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
--max_template_date 2021-10-10 \
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
--config_preset finetuning \
--alignment_index_path $DATA_DIR/alignment_data/alignment_db/alignment_db.index \
--seed 4242022 \
--obsolete_pdbs_file_path $DATA_DIR/pdb_data/obsolete.dat \
--num_nodes 1 \
--gpus 4 \
--num_workers 4 \
--resume_from_ckpt $CHECKPOINT_PATH \
--resume_model_weights_only
```
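When fine-tuning from one of your own earlier runs, you need a value for `$CHECKPOINT_PATH`. The sketch below picks the most recently modified `.ckpt` file under a directory. It assumes the earlier run saved PyTorch Lightning `.ckpt` files somewhere under that directory; `latest_checkpoint` is a hypothetical helper, and where your run actually writes checkpoints may differ.

```bash
# Hypothetical helper (not part of OpenFold): print the newest *.ckpt file under a directory.
# Returns 1 if no checkpoint is found.
latest_checkpoint() {
  local search_dir="$1"
  find "$search_dir" -type f -name '*.ckpt' | {
    newest=""
    while IFS= read -r ckpt; do
      # -nt compares modification times: true if $ckpt is newer than $newest
      if [ -z "$newest" ] || [ "$ckpt" -nt "$newest" ]; then
        newest="$ckpt"
      fi
    done
    # Succeeds (and prints the path) only if at least one checkpoint was seen
    [ -n "$newest" ] && printf '%s\n' "$newest"
  }
}
```

For example, `CHECKPOINT_PATH=$(latest_checkpoint "$OUTPUT_DIR")` before running the command above. Comparing modification times avoids relying on any particular checkpoint naming scheme.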
If you have model parameters from OpenFold v1.x, you may need to convert your checkpoint file or parameters. See [Converting OpenFold v1 Weights](convert_of_v1_weights.md) for more details.
### Using MPI
If MPI is configured on your system and you would like to use it to train OpenFold models, follow these steps:
1. Install the `mpi4py` package, which is available through pip and conda. See the [mpi4py documentation](https://pypi.org/project/mpi4py/) for installation instructions.
2. Add the `--mpi_plugin` flag to your training command.
### Training Multimer models
```{note}
Coming soon.
```