Compare commits
90 Commits
f545323c44
...
b7f35dce65
Author | SHA1 | Date |
---|---|---|
Christina Floristean | b7f35dce65 | |
Christina Floristean | 5aa549583a | |
Christina Floristean | 099769d2ec | |
Jennifer Wei | 2338b896c7 | |
dependabot[bot] | dcd809d9c2 | |
Jennifer Wei | 51472e756a | |
dependabot[bot] | 13728b1203 | |
Jennifer Wei | f7dba95f0b | |
Jennifer Wei | e3716118cd | |
Jennifer Wei | 6df89c763f | |
Jennifer Wei | 13f0f6fe16 | |
Sachin Kadyan | 9e32781fd6 | |
jnwei | f68a6c694b | |
jnwei | a5c69a79c6 | |
jnwei | e2bb3c4b90 | |
jnwei | 6fe34248b2 | |
jnwei | 5efba4425a | |
jnwei | 3817d94098 | |
jnwei | a5a86d4323 | |
Jennifer Wei | f06657fe8a | |
Sachin Kadyan | 2d4fe4f414 | |
jnwei | a90da39554 | |
Sachin Kadyan | 86b990d6ed | |
Sachin Kadyan | 8185c30775 | |
Sachin Kadyan | 4c8e37644e | |
jnwei | 5f5c8f2a5b | |
Matthew W. Thompson | 7666c80272 | |
Matthew W. Thompson | 582103505d | |
Matthew W. Thompson | 736d668741 | |
Matthew W. Thompson | 32c11376d7 | |
Matthew W. Thompson | f86d42f40e | |
Matthew W. Thompson | 6bf5c8cea1 | |
Sachin Kadyan | 92835fd5e6 | |
Sachin Kadyan | 0026173e23 | |
Sachin Kadyan | bcc6d97b69 | |
Gustaf Ahdritz | 0c20e3c989 | |
jnwei | d6ae9f5894 | |
Jennifer Wei | b3a118fc83 | |
Jennifer Wei | 2893fd934b | |
Jennifer Wei | 3e3f07c7f2 | |
Jennifer Wei | fcba33580e | |
Gustaf Ahdritz | 2300f6720d | |
Jennifer | d77a8dabea | |
Jennifer | fb34a0cb62 | |
Jennifer Wei | 705c26773d | |
Jennifer Wei | 4fde713c05 | |
Jennifer Wei | 7922bd57f1 | |
Sachin Kadyan | 6381ddd6e9 | |
Sachin Kadyan | e8de822e9b | |
Sachin Kadyan | c8c1239723 | |
Sachin Kadyan | b45a91ba5c | |
Sachin Kadyan | 3f592307eb | |
Gustaf Ahdritz | 6aefa986a8 | |
Sachin Kadyan | 3c240cb3f2 | |
Sachin Kadyan | 28334db382 | |
Sachin Kadyan | a7c0d0d178 | |
Sachin Kadyan | 777d738a59 | |
Sachin Kadyan | 6012b9e1c1 | |
Sachin Kadyan | 08ef6e9fb6 | |
Sachin Kadyan | 395a9f1ba8 | |
sachinkadyan7 | f85d67f4f9 | |
sachinkadyan7 | 9b114f28df | |
sachinkadyan7 | 5a8d2b78c1 | |
sachinkadyan7 | 36d5708cfd | |
Sachin Kadyan | 047e69af8d | |
Sachin Kadyan | 624b5aa698 | |
Sachin Kadyan | 299629903b | |
Sachin Kadyan | a83c6fcc3e | |
Sachin Kadyan | 8c94482aa0 | |
Sachin Kadyan | 380947c429 | |
Sachin Kadyan | f2540236b7 | |
Sachin Kadyan | 43d0964536 | |
Sachin Kadyan | 7f84eebd48 | |
Sachin Kadyan | a51f5fb585 | |
Sachin Kadyan | bbdaacfd17 | |
Sachin Kadyan | a6a467e09c | |
Sachin Kadyan | 2ba07feb88 | |
Sachin Kadyan | bc3ba06ef1 | |
Sachin Kadyan | 6403401fb6 | |
Sachin Kadyan | 75889e9a9a | |
Sachin Kadyan | aacf1b6fb2 | |
Sachin Kadyan | cf054ce9e3 | |
Sachin Kadyan | 19d090cb92 | |
Sachin Kadyan | e40900d897 | |
Sachin Kadyan | 21a88b6ff9 | |
Sachin Kadyan | 40325b186e | |
jnwei | 60d0b15ac3 | |
Jennifer Wei | 73ff40b655 | |
Jennifer Wei | 8baae516a4 | |
Jennifer Wei | 48668ca30b | |
@@ -0,0 +1,7 @@
version: 2
updates:

  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "daily"
@@ -10,6 +10,6 @@ jobs:
  build:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2
    - uses: actions/checkout@v4
    - name: Build the Docker image
      run: docker build . --file Dockerfile --tag openfold:$(date +%s)
@@ -4,8 +4,8 @@ jobs:
  undefined_names:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v2
    - uses: actions/setup-python@v2
    - uses: actions/checkout@v4
    - uses: actions/setup-python@v4
    - run: pip install --upgrade pip
    - run: pip install flake8
    - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
Dockerfile
@@ -1,4 +1,4 @@
FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04

# metainformation
LABEL org.opencontainers.image.version = "1.0.0"
@@ -13,24 +13,23 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/

RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \
    "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
    && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
    && rm /tmp/Miniconda3-latest-Linux-x86_64.sh
    "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \
    && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
    && rm /tmp/Miniforge3-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH

COPY environment.yml /opt/openfold/environment.yml

# installing into the base environment since the docker container wont do anything other than run openfold
RUN conda env update -n base --file /opt/openfold/environment.yml && conda clean --all
RUN mamba env update -n base --file /opt/openfold/environment.yml && mamba clean --all
RUN export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}

COPY openfold /opt/openfold/openfold
COPY scripts /opt/openfold/scripts
COPY run_pretrained_openfold.py /opt/openfold/run_pretrained_openfold.py
COPY train_openfold.py /opt/openfold/train_openfold.py
COPY setup.py /opt/openfold/setup.py
COPY lib/openmm.patch /opt/openfold/lib/openmm.patch
RUN wget -q -P /opt/openfold/openfold/resources \
    https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
RUN patch -p0 -d /opt/conda/lib/python3.9/site-packages/ < /opt/openfold/lib/openmm.patch
WORKDIR /opt/openfold
RUN python3 setup.py install
README.md
@@ -29,7 +29,7 @@ vice versa (see `scripts/convert_of_weights_to_jax.py`).

OpenFold has the following advantages over the reference implementation:

- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on (>= Ampere) GPUs.
- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on Ampere or higher architecture GPUs.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
  ([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
  sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
@@ -49,37 +49,19 @@ and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (night
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.

For convenience, we provide a script that installs Miniconda locally, creates a
`conda` virtual environment, installs all Python dependencies, and downloads
useful resources, including both sets of model parameters. Run:
This package is currently supported for CUDA 11 and PyTorch 1.12.

```bash
scripts/install_third_party_dependencies.sh
```
To install:
1. Clone the repository, e.g. `git clone https://github.com/aqlaboratory/openfold.git`
1. From the `openfold` repo:
    - Create a [Mamba](https://github.com/conda-forge/miniforge/releases/latest/download/) environment, e.g.
      `mamba env create -n openfold_env -f environment.yml`.
      Mamba is recommended because the dependencies required by OpenFold are quite large and Mamba can speed up the process.
    - Activate the environment, e.g. `conda activate openfold_env`
1. Run `scripts/install_third_party_dependencies.sh` to configure kernels and folding resources.

To activate the environment, run:
For some systems, it may help to append the Conda environment library path to `$LD_LIBRARY_PATH`. The `install_third_party_dependencies.sh` script does this once, but you may need to do this for each bash instance.

```bash
source scripts/activate_conda_env.sh
```

To deactivate it, run:

```bash
source scripts/deactivate_conda_env.sh
```

With the environment active, compile OpenFold's CUDA kernels with

```bash
python3 setup.py install
```

To install the HH-suite to `/usr/bin`, run

```bash
# scripts/install_hh_suite.sh
```

## Usage
@@ -233,6 +215,51 @@ efficent AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option

#### SoloSeq Inference
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.

For generating ESM-1b embeddings in bulk, use the provided script: `scripts/precompute_embeddings.py`. The script takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure as required by SoloSeq. Following is an example command to use the script:

```bash
python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
```

In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which can contain the details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB MMCIF dataset to the command.

Now, you are ready to run inference:

```bash
python run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --use_precomputed_alignments embeddings_output_dir \
    --output_dir ./ \
    --model_device "cuda:0" \
    --config_preset "seq_model_esm1b_ptm" \
    --openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt
```

For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates and only generated ESM-1b embeddings will be used to predict the structure.

```bash
python3 run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --output_dir ./ \
    --model_device "cuda:0" \
    --config_preset "seq_model_esm1b_ptm" \
    --openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \
    --uniref90_database_path data/uniref90/uniref90.fasta \
    --pdb70_database_path data/pdb70/pdb70 \
    --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
    --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
    --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
```

For generating template information, you will need the UniRef90 and PDB70 databases and the JackHmmer and HHSearch binaries.

SoloSeq allows you to use the same flags and optimizations as the MSA-based OpenFold. For example, you can skip relaxation using `--skip_relaxation`, save all model outputs using `--save_outputs`, and generate output files in MMCIF format using `--cif_output`.

**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.

### Training

To train the model, you will first need to precompute protein alignments.
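As a quick orientation for the SoloSeq workflow added above, a minimal sketch (not part of this diff; the `T1078` label and directory names are made up) of loading one embedding file written by `scripts/precompute_embeddings.py`. The layer-33 ESM-1b representation is what the data pipeline later exposes as the `seq_embedding` feature:

```python
import torch

# precompute_embeddings.py writes embeddings_output_dir/<label>/<label>.pt
emb = torch.load("embeddings_output_dir/T1078/T1078.pt")   # hypothetical label/path
rep = emb["representations"][33]   # per-residue ESM-1b embedding, shape [seq_len, 1280]
print(emb["label"], rep.shape)
```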
@@ -440,7 +467,7 @@ Please cite our paper:

```bibtex
@article {Ahdritz2022.11.20.517210,
  author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
  author = {Ahdritz, Gustaf and Bouatta, Nazim and Floristean, Christina and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
  title = {{O}pen{F}old: {R}etraining {A}lpha{F}old2 yields new insights into its learning mechanisms and capacity for generalization},
  elocation-id = {2022.11.20.517210},
  year = {2022},
@@ -1,33 +1,36 @@
name: openfold_venv
name: openfold-venv
channels:
- conda-forge
- bioconda
- pytorch
dependencies:
- conda-forge::python=3.9
- conda-forge::setuptools=59.5.0
- conda-forge::pip
- conda-forge::openmm=7.5.1
- conda-forge::pdbfixer
- conda-forge::cudatoolkit==11.3.*
- python=3.9
- libgcc=7.2
- setuptools=59.5.0
- pip
- openmm=7.7
- pdbfixer
- cudatoolkit==11.3.*
- pytorch-lightning==1.5.10
- biopython==1.79
- numpy==1.21
- PyYAML==5.4.1
- requests
- scipy==1.7
- tqdm==4.62.2
- typing-extensions==3.10
- wandb==0.12.21
- modelcif==0.7
- awscli
- ml-collections
- aria2
- git
- bioconda::hmmer==3.3.2
- bioconda::hhsuite==3.3.0
- bioconda::kalign2==2.04
- pytorch::pytorch=1.12.*
- pip:
- biopython==1.79
- deepspeed==0.12.2
- dm-tree==0.1.6
- ml-collections==0.1.0
- numpy==1.21.2
- PyYAML==5.4.1
- requests==2.26.0
- scipy==1.7.1
- tqdm==4.62.2
- typing-extensions==3.10.0.2
- pytorch_lightning==1.5.10
- wandb==0.12.21
- modelcif==0.7
- git+https://github.com/NVIDIA/dllogger.git
- git+https://github.com/microsoft/DeepSpeed.git
# TODO: Replace above when version becomes available
# - deepspeed==0.10.4
- git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
@@ -1,42 +0,0 @@
Index: simtk/openmm/app/topology.py
===================================================================
--- simtk.orig/openmm/app/topology.py
+++ simtk/openmm/app/topology.py
@@ -356,19 +356,35 @@
def isCyx(res):
names = [atom.name for atom in res._atoms]
return 'SG' in names and 'HG' not in names
+ # This function is used to prevent multiple di-sulfide bonds from being
+ # assigned to a given atom. This is a DeepMind modification.
+ def isDisulfideBonded(atom):
+ for b in self._bonds:
+ if (atom in b and b[0].name == 'SG' and
+ b[1].name == 'SG'):
+ return True
+
+ return False

cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)]
atomNames = [[atom.name for atom in res._atoms] for res in cyx]
for i in range(len(cyx)):
sg1 = cyx[i]._atoms[atomNames[i].index('SG')]
pos1 = positions[sg1.index]
+ candidate_distance, candidate_atom = 0.3*nanometers, None
for j in range(i):
sg2 = cyx[j]._atoms[atomNames[j].index('SG')]
pos2 = positions[sg2.index]
delta = [x-y for (x,y) in zip(pos1, pos2)]
distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2])
- if distance < 0.3*nanometers:
- self.addBond(sg1, sg2)
+ if distance < candidate_distance and not isDisulfideBonded(sg2):
+ candidate_distance = distance
+ candidate_atom = sg2
+ # Assign bond to closest pair.
+ if candidate_atom:
+ self.addBond(sg1, candidate_atom)
+
+

class Chain(object):
"""A Chain object represents a chain within a Topology."""
@@ -152,9 +152,42 @@ def model_config(
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# SINGLE SEQUENCE EMBEDDING PRESETS
elif name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "seq_model_esm1b":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
elif name == "seq_model_esm1b_ptm":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
else:
raise ValueError("Invalid model name")

if name.startswith("seq"):
# Tell the data pipeline that we will use sequence embeddings instead of MSAs.
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
# In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.update(seq_mode_config.copy_and_resolve_references())

if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True

@@ -189,6 +222,11 @@ c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)

# For seqemb mode, dimension size of the per-residue sequence embedding passed to the model
# In current model, the dimension size is the ESM-1b dimension size i.e. 1280.
preemb_dim_size = mlc.FieldReference(1280, field_type=int)

blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)

@@ -301,6 +339,9 @@ config = mlc.ConfigDict(
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"seqemb_mode": { # Configuration for sequence embedding mode
"enabled": False, # If True, use seq emb instead of MSA
},
"supervised": {
"clamp_prob": 0.9,
"supervised_features": [

@@ -365,6 +406,7 @@ config = mlc.ConfigDict(
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use DeepSpeed memory-efficient attention kernel. Mutually

@@ -497,6 +539,7 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"no_column_attention": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,

@@ -618,3 +661,31 @@ config = mlc.ConfigDict(
"ema": {"decay": 0.999},
}
)

seq_mode_config = mlc.ConfigDict({
"data": {
"common": {
"feat": {
"seq_embedding": [NUM_RES, None],
},
"seqemb_features": [ # List of features to be generated in seqemb mode
"seq_embedding"
],
},
"seqemb_mode": { # Configuration for sequence embedding mode
"enabled": True, # If True, use seq emb instead of MSA
},
},
"globals": {
"seqemb_mode_enabled": True,
},
"model": {
"preembedding_embedder": { # Used in sequence embedding mode
"tf_dim": 22,
"preembedding_dim": preemb_dim_size,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
}
})
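The new presets above are selected by name through `model_config`. A minimal usage sketch, assuming the function lives in `openfold.config` as the hunk context suggests (not part of the diff itself):

```python
from openfold.config import model_config

# "seq_model_esm1b_ptm" is one of the presets added above
c = model_config("seq_model_esm1b_ptm")
assert c.globals.seqemb_mode_enabled                 # set for all "seq*" presets
assert c.model.evoformer_stack.no_column_attention   # column attention disabled in seqemb mode
assert c.data.seqemb_mode.enabled
```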
@@ -186,7 +186,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled
)

return data

@@ -239,6 +240,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
elif(ext == ".pdb"):
structure_index = None

@@ -251,6 +253,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=structure_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
else:
raise ValueError("Extension branch missing")

@@ -260,6 +263,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)

if(self._output_raw):
@@ -19,6 +19,7 @@ from multiprocessing import cpu_count
from typing import Mapping, Optional, Sequence, Any

import numpy as np
import torch

from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.templates import get_custom_template_features

@@ -260,6 +261,18 @@ def make_msa_features(
return features


# Generate 1-sequence MSA features having only the input sequence
def make_dummy_msa_feats(input_sequence):
msas = [[input_sequence]]
deletion_matrices = [[[0 for _ in input_sequence]]]
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)

return msa_features


def make_sequence_features_with_custom_template(
sequence: str,
mmcif_path: str,

@@ -627,11 +640,28 @@ class DataPipeline:

return msa_features

# Load and process sequence embedding features
def _process_seqemb_features(self,
alignment_dir: str,
) -> Mapping[str, Any]:
seqemb_features = {}
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]

if (ext == ".pt"):
# Load embedding file
seqemb_data = torch.load(path)
seqemb_features["seq_embedding"] = seqemb_data["representations"][33]

return seqemb_features

def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:

@@ -658,12 +688,19 @@ class DataPipeline:
num_res=num_res,
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
sequence_embedding_features = {}
# If using seqemb mode, generate a dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {
**sequence_features,
**msa_features,
**template_features
**template_features,
**sequence_embedding_features,
}

def process_mmcif(

@@ -672,6 +709,7 @@ class DataPipeline:
alignment_dir: str,
chain_id: Optional[str] = None,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.

@@ -696,10 +734,16 @@ class DataPipeline:
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {**mmcif_feats, **template_features, **msa_features}
sequence_embedding_features = {}
# If using seqemb mode, generate a dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {**mmcif_feats, **template_features, **msa_features, **sequence_embedding_features}

def process_pdb(
self,

@@ -709,6 +753,7 @@ class DataPipeline:
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.

@@ -742,15 +787,22 @@ class DataPipeline:
self.template_featurizer,
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
sequence_embedding_features = {}
# If in sequence embedding mode, generate dummy MSA features using just the input sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {**pdb_feats, **template_features, **msa_features}
return {**pdb_feats, **template_features, **msa_features, **sequence_embedding_features}

def process_core(
self,
core_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.

@@ -770,9 +822,15 @@ class DataPipeline:
self.template_featurizer,
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence)
sequence_embedding_features = {}
# If in sequence embedding mode, generate dummy MSA features using just the input sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence)

return {**core_feats, **template_features, **msa_features}
return {**core_feats, **template_features, **msa_features, **sequence_embedding_features}

def process_multiseq_fasta(self,
fasta_path: str,
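A sketch of how the new `seqemb_mode` flag changes feature assembly. The `template_featurizer=None` constructor argument and the paths here are assumptions for illustration, not part of the diff:

```python
from openfold.data import data_pipeline

processor = data_pipeline.DataPipeline(template_featurizer=None)  # assumed constructor usage
feats = processor.process_fasta(
    fasta_path="fasta_dir/example.fasta",           # hypothetical input FASTA
    alignment_dir="embeddings_output_dir/example",  # directory holding example.pt
    seqemb_mode=True,  # dummy 1-sequence MSA features + "seq_embedding" loaded from the .pt file
)
```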
@@ -40,9 +40,11 @@ def np_to_tensor_dict(
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
"""
# torch generates warnings if feature is already a torch Tensor
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features
k: to_tensor(v) for k, v in np_example.items() if k in features
}

return tensor_dict

@@ -61,6 +63,10 @@ def make_data_config(

feature_names = cfg.common.unsupervised_features

# Add seqemb related features if using seqemb mode.
if cfg.seqemb_mode.enabled:
feature_names += cfg.common.seqemb_features

if cfg.common.use_templates:
feature_names += cfg.common.template_features
@@ -139,6 +139,100 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb


class PreembeddingEmbedder(nn.Module):
"""
Embeds the sequence pre-embedding passed to the model and the target_feat features.
"""

def __init__(
self,
tf_dim: int,
preembedding_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
End channel dimension of the incoming target features
preembedding_dim:
End channel dimension of the incoming embeddings
c_z:
Pair embedding dimension
c_m:
Single-Seq embedding dimension
relpos_k:
Window size used in relative position encoding
"""
super(PreembeddingEmbedder, self).__init__()

self.tf_dim = tf_dim
self.preembedding_dim = preembedding_dim

self.c_z = c_z
self.c_m = c_m

self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_preemb_m = Linear(self.preembedding_dim, c_m)
self.linear_preemb_z_i = Linear(self.preembedding_dim, c_z)
self.linear_preemb_z_j = Linear(self.preembedding_dim, c_z)

# Relative Positional Encoding
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)

def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Args:
ri:
"residue_index" feature of shape [*, N]
Returns:
Relative positional encoding of protein using the
residue_index feature
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
d = d[..., None] - reshaped_bins
d = torch.abs(d)
d = torch.argmin(d, dim=-1)
d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
d = d.to(ri.dtype)
return self.linear_relpos(d)

def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
preemb: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:

tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
)
preemb_emb = self.linear_preemb_m(preemb[..., None, :, :]) + tf_m
preemb_emb_i = self.linear_preemb_z_i(preemb)
preemb_emb_j = self.linear_preemb_z_j(preemb)

pair_emb = self.relpos(ri.type(preemb_emb_i.dtype))
pair_emb = add(pair_emb,
preemb_emb_i[..., None, :],
inplace=inplace_safe)
pair_emb = add(pair_emb,
preemb_emb_j[..., None, :, :],
inplace=inplace_safe)

return preemb_emb, pair_emb


class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
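For reference (editor's note, not part of the diff): the `relpos` computation above — subtract residue indices, take the argmin over the `2k+1` bin offsets, then one-hot — is equivalent to clipping the pairwise offset to the window and one-hot encoding it before the linear projection into the pair representation:

$$
d_{ij} = \operatorname{clip}(r_i - r_j,\ -k,\ k) + k, \qquad
z_{ij} \mathrel{+}= W_{\text{relpos}}\,\operatorname{onehot}\bigl(d_{ij};\ 2k+1\bigr)
$$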
@@ -87,7 +87,6 @@ class MSATransition(nn.Module):
no_batch_dims=len(m.shape[:-2]),
)


def forward(
self,
m: torch.Tensor,

@@ -326,6 +325,7 @@ class EvoformerBlock(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
no_column_attention: bool,
inf: float,
eps: float,
):

@@ -339,12 +339,15 @@ class EvoformerBlock(nn.Module):
inf=inf,
)

self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
# Specifically, seqemb mode does not use column attention
self.no_column_attention = no_column_attention
if not self.no_column_attention:
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)

self.msa_dropout_layer = DropoutRowwise(msa_dropout)

@@ -402,18 +405,20 @@ class EvoformerBlock(nn.Module):
),
inplace=inplace_safe,
)

m = add(m,
self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
),
inplace=inplace_safe,
)

# Specifically, column attention is not used in seqemb mode.
if not self.no_column_attention:
m = add(m,
self.msa_att_col(
m,
mask=msa_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_lma=use_lma,
use_flash=use_flash,
),
inplace=inplace_safe,
)

if(not inplace_safe):
input_tensors = [m, input_tensors[1]]

@@ -605,6 +610,7 @@ class EvoformerStack(nn.Module):
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
no_column_attention: bool,
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,

@@ -642,6 +648,9 @@ class EvoformerStack(nn.Module):
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
no_column_attention:
When True, doesn't use column attention. Required for running
sequence embedding mode
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation

@@ -668,6 +677,7 @@ class EvoformerStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
no_column_attention=no_column_attention,
inf=inf,
eps=eps,
)
@@ -24,6 +24,7 @@ from openfold.model.embedders import (
TemplateAngleEmbedder,
TemplatePairEmbedder,
ExtraMSAEmbedder,
PreembeddingEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads

@@ -71,11 +72,19 @@ class AlphaFold(nn.Module):
self.config = config.model
self.template_config = self.config.template
self.extra_msa_config = self.config.extra_msa
self.seqemb_mode = config.globals.seqemb_mode_enabled

# Main trunk + structure module
self.input_embedder = InputEmbedder(
**self.config["input_embedder"],
)
# If using seqemb mode, embed the sequence embeddings passed
# to the model ("preembeddings") instead of embedding the sequence
if self.seqemb_mode:
self.input_embedder = PreembeddingEmbedder(
**self.config["preembedding_embedder"],
)
else:
self.input_embedder = InputEmbedder(
**self.config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**self.config["recycling_embedder"],
)

@@ -238,17 +247,27 @@ class AlphaFold(nn.Module):
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]

## Initialize the MSA and pair representations

# m: [*, S_c, N, C_m]
## Initialize the SingleSeq and pair representations
# m: [*, 1, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)
if self.seqemb_mode:
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["seq_embedding"]
)

else:
## Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
inplace_safe=inplace_safe,
)

# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function, saving memory
@@ -23,7 +23,7 @@ if deepspeed_is_installed:
import deepspeed

if ds4s_is_installed:
from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
from deepspeed.ops.deepspeed4science import EvoformerFusedAttention

fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if fa_is_installed:

@@ -661,18 +661,19 @@ def _deepspeed_evo_attn(
v = reshape_dims(v)
biases = [reshape_dims(b) for b in biases]

biases.extend([None] * (2 - len(biases)))

# DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
# Cast to bf16 so kernel can be used during inference
orig_dtype = q.dtype
if orig_dtype not in [torch.bfloat16, torch.float16]:
o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
k.to(dtype=torch.bfloat16),
v.to(dtype=torch.bfloat16),
[b.to(dtype=torch.bfloat16) for b in biases])
inputs_bf16 = [x.to(dtype=torch.bfloat16) if x is not None else x
for x in (q, k, v, biases[0], biases[1])]
o = EvoformerFusedAttention.apply(*inputs_bf16)

o = o.to(dtype=orig_dtype)
else:
o = DS4Sci_EvoformerAttention(q, k, v, biases)
o = EvoformerFusedAttention.apply(q, k, v, biases[0], biases[1])

o = o.reshape(orig_shape)
return o
@@ -28,18 +28,10 @@ import openfold.utils.loss as loss
from openfold.np.relax import cleanup, utils
import ml_collections
import numpy as np
try:
# openmm >= 7.6
import openmm
from openmm import unit
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk import openmm
from simtk import unit
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
import openmm
from openmm import unit
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure

ENERGY = unit.kilocalories_per_mole
LENGTH = unit.angstroms
@@ -20,14 +20,8 @@ cases like removing chains of length one (see clean_structure).
import io

import pdbfixer
try:
# openmm >= 7.6
from openmm import app
from openmm.app import element
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app
from simtk.openmm.app import element
from openmm import app
from openmm.app import element


def fix_pdb(pdbfile, alterations_info):
@@ -18,14 +18,8 @@ import io
from openfold.np import residue_constants
from Bio import PDB
import numpy as np
try:
# openmm >= 7.6
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure
except ImportError:
# openmm < 7.6 (requires DeepMind patch)
from simtk.openmm import app as openmm_app
from simtk.openmm.app.internal.pdbstructure import PdbStructure
from openmm import app as openmm_app
from openmm.app.internal.pdbstructure import PdbStructure


def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
@@ -159,7 +159,7 @@ def run_model(model, batch, tag, output_dir):
out = model(batch)
inference_time = time.perf_counter() - t
logger.info(f"Inference time: {inference_time}")
update_timings({"inference": inference_time}, os.path.join(output_dir, "timings.json"))
update_timings({tag: {"inference": inference_time}}, os.path.join(output_dir, "timings.json"))

model.config.template.enabled = template_enabled
@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
pad_feature_dict_seq,
trace_model_,
)
from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args


@@ -73,17 +74,29 @@ def precompute_alignments(tags, seqs, alignment_dir, args):

os.makedirs(local_alignment_dir)

alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
no_cpus=args.cpus,
)
# In seqemb mode, use AlignmentRunner only to generate templates
if args.use_single_seq_mode:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
pdb70_database_path=args.pdb70_database_path,
no_cpus=args.cpus,
)
embedding_generator = EmbeddingGenerator()
embedding_generator.run(tmp_fasta_path, alignment_dir)
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
pdb70_database_path=args.pdb70_database_path,
no_cpus=args.cpus,
)
alignment_runner.run(
tmp_fasta_path, local_alignment_dir
)

@@ -116,7 +129,9 @@ def generate_feature_dict(

local_alignment_dir = os.path.join(alignment_dir, tag)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
fasta_path=tmp_fasta_path,
alignment_dir=local_alignment_dir,
seqemb_mode=args.use_single_seq_mode,
)
else:
with open(tmp_fasta_path, "w") as fp:

@@ -140,6 +155,8 @@ def main(args):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)

if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)

if(args.trace_model):

@@ -314,6 +331,10 @@ if __name__ == "__main__":
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored."""
)
parser.add_argument(
"--use_single_seq_mode", action="store_true", default=False,
help="""Use single sequence embeddings instead of MSAs."""
)
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
@@ -1,54 +1,26 @@
#!/bin/bash
CONDA_INSTALL_URL=${CONDA_INSTALL_URL:-"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh"}

source scripts/vars.sh

# Install Miniconda locally
rm -rf lib/conda
rm -f /tmp/Miniconda3-latest-Linux-x86_64.sh
wget -P /tmp \
    "${CONDA_INSTALL_URL}" \
    && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p lib/conda \
    && rm /tmp/Miniconda3-latest-Linux-x86_64.sh

# Grab conda-only packages
export PATH=lib/conda/bin:$PATH
lib/conda/bin/python3 -m pip install nvidia-pyindex
conda env create --name=${ENV_NAME} -f environment.yml
source scripts/activate_conda_env.sh

echo "Attempting to install FlashAttention"
git clone https://github.com/HazyResearch/flash-attention
CUR_DIR=$PWD
cd flash-attention
git checkout 5b838a8bef
python3 setup.py install
cd $CUR_DIR

echo "Attempting to download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass.git
conda env config vars set CUTLASS_PATH=$PWD/cutlass
source scripts/activate_conda_env.sh

# Install DeepMind's OpenMM patch
OPENFOLD_DIR=$PWD
pushd lib/conda/envs/$ENV_NAME/lib/python3.9/site-packages/ \
    && patch -p0 < $OPENFOLD_DIR/lib/openmm.patch \
    && popd

# Download folding resources
wget --no-check-certificate -P openfold/resources \
wget -N --no-check-certificate -P openfold/resources \
    https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt

# Certain tests need access to this file
mkdir -p tests/test_data/alphafold/common
ln -rs openfold/resources/stereo_chemical_props.txt tests/test_data/alphafold/common

echo "Downloading OpenFold parameters..."
bash scripts/download_openfold_params.sh openfold/resources

echo "Downloading AlphaFold parameters..."
bash scripts/download_alphafold_params.sh openfold/resources

# Decompress test data
gunzip tests/test_data/sample_feats.pickle.gz
gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.pickle

python setup.py install

export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH

echo "Attempting to download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass --depth 1
conda env config vars set CUTLASS_PATH=$PWD/cutlass

# This setting is used to fix a worker assignment issue during data loading
conda env config vars set KMP_AFFINITY=none

# Reactivate env so that the above environment variables take effect
conda activate $CONDA_PREFIX
@@ -0,0 +1,200 @@
# Some functions borrowed from [ESM](https://www.github.com/facebookresearch/esm)
import argparse
import logging
import os

import torch

from openfold.data import parsers

logging.basicConfig(level=logging.INFO)

class SequenceDataset(object):
def __init__(self, labels, sequences) -> None:
self.labels = labels
self.sequences = sequences

@classmethod
def from_file(cls, fasta_file):
labels, sequences = [], []

with open(fasta_file, "r") as infile:
fasta_str = infile.read()
sequences, labels = parsers.parse_fasta(fasta_str)

assert len(set(labels)) == len(labels),\
"Sequence labels need to be unique. Duplicates found!"

return cls(labels, sequences)

def __len__(self):
return len(self.labels)

def __getitem__(self, idx):
return self.labels[idx], self.sequences[idx]

def get_batch_indices(self, toks_per_batch, extra_toks_per_seq):
sizes = [(len(s), i) for i, s in enumerate(self.sequences)]
sizes.sort()
batches = []
buf = []
max_len = 0

def _flush_current_buf():
nonlocal max_len, buf
if len(buf) == 0:
return
batches.append(buf)
buf = []
max_len = 0

for sz, i in sizes:
sz += extra_toks_per_seq
if max(sz, max_len) * (len(buf)+1) > toks_per_batch:
_flush_current_buf()
max_len = max(max_len, sz)
buf.append(i)

_flush_current_buf()
return batches


class EmbeddingGenerator:
"""Generates the ESM-1b embeddings for the single sequence model"""
def __init__(self,
toks_per_batch: int = 4096,
truncate: bool = True,
use_local_esm: str = None,
nogpu: bool = False,
):
self.toks_per_batch = toks_per_batch
self.truncate = truncate
self.use_local_esm = use_local_esm
self.nogpu = nogpu

# Generate embeddings in bulk
if self.use_local_esm:
self.model, self.alphabet = torch.hub.load(self.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
else:
self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
if torch.cuda.is_available() and not self.nogpu:
self.model = self.model.to(device="cuda")

def parse_sequences(self, fasta_dir, output_dir):
labels = []
seqs = []

# Generate a single bulk file
for f in os.listdir(fasta_dir):
f_name, ext = os.path.splitext(f)
if ext != '.fasta' and ext != '.fa':
logging.warning(f"Ignoring non-FASTA file: {f}")
continue
with open(os.path.join(fasta_dir, f), 'r') as infile:
seq = infile.readlines()[1].strip()
labels.append(f_name)
seqs.append(seq)

lines = []
for label, seq in zip(labels, seqs):
lines += f'>{label}\n'
lines += f'{seq}\n'
os.makedirs(output_dir, exist_ok=True)
temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
with open(temp_fasta_file, 'w') as outfile:
outfile.writelines(lines)
return temp_fasta_file

def run(
self,
fasta_file,
output_dir,
):

dataset = SequenceDataset.from_file(fasta_file)
batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
)
logging.info("Loaded all sequences")
repr_layers = [33]

with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
if torch.cuda.is_available() and not self.nogpu:
toks = toks.to(device="cuda", non_blocking=True)

if self.truncate:
toks = toks[:1022]

out = self.model(toks, repr_layers=repr_layers, return_contacts=False)

representations = {
33: out["representations"][33].to(device="cpu")
}

for i, label in enumerate(labels):
os.makedirs(os.path.join(output_dir, label), exist_ok=True)
result = {"label": label}

result["representations"] = {
33: representations[33][i, 1: len(strs[i]) + 1].clone()
}
torch.save(
result,
os.path.join(output_dir, label, label+".pt")
)


def main(args):
logging.info("Loading the model...")
embedding_generator = EmbeddingGenerator(
args.toks_per_batch,
args.truncate,
args.use_local_esm,
args.nogpu)
logging.info("Loading the sequences and running the inference...")
temp_fasta_file = embedding_generator.parse_sequences(
args.fasta_dir,
args.output_dir
)
embedding_generator.run(
temp_fasta_file,
args.output_dir
)
os.remove(temp_fasta_file)
logging.info("Completed.")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_dir", type=str,
help="""Path to directory containing FASTA files."""
)
parser.add_argument(
"output_dir", type=str,
help="Directory in which to output embeddings"
)
parser.add_argument(
"--toks_per_batch", type=int, default=4096,
help="maximum tokens in a batch"
)
parser.add_argument(
"--truncate", action="store_true", default=True,
help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
)
parser.add_argument(
"--use_local_esm", type=str, default=None,
help="Use a local ESM repository instead of cloning from Github"
)
parser.add_argument(
"--nogpu", action="store_true",
help="Do not use GPU"
)

args = parser.parse_args()

main(args)
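Besides the CLI entry point above, `run_pretrained_openfold.py` (earlier in this diff) drives this class directly when `--use_single_seq_mode` is set. A condensed sketch of that programmatic path (the directory names are placeholders):

```python
from scripts.precompute_embeddings import EmbeddingGenerator

gen = EmbeddingGenerator()                      # loads esm1b_t33_650M_UR50S via torch.hub
fasta = gen.parse_sequences("fasta_dir/", "embeddings_output_dir/")  # writes a temporary bulk FASTA
gen.run(fasta, "embeddings_output_dir/")        # saves <label>/<label>.pt embedding files
```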
@@ -17,6 +17,7 @@ import numpy as np
import unittest
from openfold.model.embedders import (
InputEmbedder,
PreembeddingEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,

@@ -46,6 +47,28 @@ class TestInputEmbedder(unittest.TestCase):
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))


class TestPreembeddingEmbedder(unittest.TestCase):
def test_shape(self):
tf_dim = 22
preembedding_dim = 1280
c_z = 4
c_m = 6
relpos_k = 10

batch_size = 4
num_res = 20

tf = torch.rand((batch_size, num_res, tf_dim))
ri = torch.rand((batch_size, num_res))
preemb = torch.rand((batch_size, num_res, preembedding_dim))

pe = PreembeddingEmbedder(tf_dim, preembedding_dim, c_z, c_m, relpos_k)

seq_emb, pair_emb = pe(tf, ri, preemb)
self.assertTrue(seq_emb.shape == (batch_size, 1, num_res, c_m))
self.assertTrue(pair_emb.shape == (batch_size, num_res, num_res, c_z))


class TestRecyclingEmbedder(unittest.TestCase):
def test_shape(self):
batch_size = 2
@@ -66,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
no_column_attention=False,
inf=inf,
eps=eps,
).eval()

@@ -86,6 +87,62 @@ class TestEvoformerStack(unittest.TestCase):
self.assertTrue(z.shape == shape_z_before)
self.assertTrue(s.shape == (batch_size, n_res, c_s))

def test_shape_without_column_attention(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
c_m = consts.c_m
c_z = consts.c_z
c_hidden_msa_att = 12
c_hidden_opm = 17
c_hidden_mul = 19
c_hidden_pair_att = 14
c_s = consts.c_s
no_heads_msa = 3
no_heads_pair = 7
no_blocks = 2
transition_n = 2
msa_dropout = 0.15
pair_stack_dropout = 0.25
inf = 1e9
eps = 1e-10

es = EvoformerStack(
c_m,
c_z,
c_hidden_msa_att,
c_hidden_opm,
c_hidden_mul,
c_hidden_pair_att,
c_s,
no_heads_msa,
no_heads_pair,
no_blocks,
transition_n,
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
no_column_attention=True,
inf=inf,
eps=eps,
).eval()

m_init = torch.rand((batch_size, n_seq, n_res, c_m))
z_init = torch.rand((batch_size, n_res, n_res, c_z))
msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))

shape_m_before = m_init.shape
shape_z_before = z_init.shape

m, z, s = es(
m_init, z_init, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
)

self.assertTrue(m.shape == shape_m_before)
self.assertTrue(z.shape == shape_z_before)
self.assertTrue(s.shape == (batch_size, n_res, c_s))

@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_ei(activations, masks):

@@ -206,7 +263,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
),
device="cuda",
)
).float()
pair_mask = torch.randint(
0,
2,

@@ -216,7 +273,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
),
device="cuda",
)
).float()

shape_z_before = z.shape
@@ -47,27 +47,27 @@ class TestModel(unittest.TestCase):
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test

model = AlphaFold(c)
model = AlphaFold(c).cuda()
model.eval()

batch = {}
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
batch["target_feat"] = nn.functional.one_hot(
tf, c.model.input_embedder.tf_dim
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
).float().cuda()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
batch["residue_index"] = torch.arange(n_res).cuda()
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(
low=0, high=2, size=(n_seq, n_res)
).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
).float().cuda()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
batch["no_recycling_iters"] = torch.tensor(2.).cuda()

add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)

@@ -77,6 +77,46 @@ class TestModel(unittest.TestCase):
with torch.no_grad():
out = model(batch)

def test_dry_run_seqemb_mode(self):
n_seq = 1
n_templ = consts.n_templ
n_res = consts.n_res
msa_dim = 49

c = model_config("seq_model_esm1b")
c.model.evoformer_stack.no_blocks = 2
c.model.evoformer_stack.blocks_per_ckpt = None
model = AlphaFold(c)
model.to(torch.device('cuda'))
model.eval()

batch = {}
tf = torch.randint(c.model.preembedding_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = nn.functional.one_hot(tf, c.model.preembedding_embedder.tf_dim).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, msa_dim))
batch["seq_embedding"] = torch.rand((n_res, c.model.preembedding_embedder.preembedding_dim))

t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})

batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(data_transforms.make_atom14_masks(batch))
batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()

batch["no_recycling_iters"] = torch.tensor(2.)
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
batch = tensor_tree_map(add_recycling_dims, batch)

to_cuda_device = lambda t: t.to(torch.device("cuda"))
batch = tensor_tree_map(to_cuda_device, batch)

with torch.no_grad():
out = model(batch)

@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_alphafold(batch):
@@ -416,6 +416,10 @@ if __name__ == "__main__":
help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
)
parser.add_argument(
"--use_single_seq_mode", type=str, default=False,
help="Use single sequence embeddings instead of MSAs."
)
parser.add_argument(
"--distillation_data_dir", type=str, default=None,
help="Directory containing training PDB files"