import os
import importlib
import importlib.util
import pkgutil
import sys
import unittest

import numpy as np
import torch

from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_

from tests.config import consts

# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"


def skip_unless_ds4s_installed():
    deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
    ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
        "deepspeed.ops.deepspeed4science") is not None
    return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4")


def skip_unless_flash_attn_installed():
    fa_is_installed = importlib.util.find_spec("flash_attn") is not None
    return unittest.skipUnless(fa_is_installed, "Requires Flash Attention")
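

# Illustrative usage only (not part of the original module): these helpers wrap
# unittest.skipUnless, so each one returns a skip decorator and is applied with
# parentheses, e.g.
#
#     class TestFlashAttention(unittest.TestCase):
#         @skip_unless_flash_attn_installed()
#         def test_attention_matches_reference(self):
#             ...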


def alphafold_is_installed():
    return importlib.util.find_spec("alphafold") is not None


def skip_unless_alphafold_installed():
    return unittest.skipUnless(alphafold_is_installed(), "Requires AlphaFold")


def import_alphafold():
    """
    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's precious insides
    """
    if "alphafold" in sys.modules:
        return sys.modules["alphafold"]
    module = importlib.import_module("alphafold")
    # Forcefully import alphafold's submodules
    submodules = pkgutil.walk_packages(module.__path__, prefix="alphafold.")
    for submodule_info in submodules:
        importlib.import_module(submodule_info.name)
    sys.modules["alphafold"] = module
    globals()["alphafold"] = module

    return module


def get_alphafold_config():
    config = alphafold.model.config.model_config(consts.model)  # noqa
    config.model.global_config.deterministic = True
    return config
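

# Illustrative sketch (not part of the original module): a comparison test would
# typically load the reference implementation once, e.g. in setUpClass, before
# requesting its config:
#
#     import_alphafold()
#     af_config = get_alphafold_config()
#
# get_alphafold_config relies on the module-level "alphafold" name that
# import_alphafold installs, hence the required call order.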


dir_path = os.path.dirname(os.path.realpath(__file__))
_param_path = os.path.join(
    dir_path, "..", f"openfold/resources/params/params_{consts.model}.npz"
)
_model = None


def get_global_pretrained_openfold():
    """Lazily build, load, and cache an OpenFold model with pretrained AlphaFold weights."""
    global _model
    if _model is None:
        _model = AlphaFold(model_config(consts.model))
        _model = _model.eval()
        if not os.path.exists(_param_path):
            raise FileNotFoundError(
                "Cannot load pretrained parameters. Make sure to run the "
                "installation script before running tests."
            )
        import_jax_weights_(_model, _param_path, version=consts.model)
        _model = _model.cuda()

    return _model


_orig_weights = None


def _get_orig_weights():
    global _orig_weights
    if _orig_weights is None:
        _orig_weights = np.load(_param_path)

    return _orig_weights


def _remove_key_prefix(d, prefix):
    for k, v in list(d.items()):
        if k.startswith(prefix):
            d.pop(k)
            d[k[len(prefix):]] = v


def fetch_alphafold_module_weights(weight_path):
    """
    Fetch the raw AlphaFold parameters whose keys contain weight_path, returned
    as a Haiku-style nested parameter dict.
    """
    orig_weights = _get_orig_weights()
    params = {k: v for k, v in orig_weights.items() if weight_path in k}
    if "/" in weight_path:
        spl = weight_path.split("/")
        spl = spl if len(spl[-1]) != 0 else spl[:-1]
        prefix = "/".join(spl[:-1]) + "/"
        _remove_key_prefix(params, prefix)

    try:
        params = alphafold.model.utils.flat_params_to_haiku(params)  # noqa
    except NameError as e:
        raise ImportError(
            "Make sure to call import_alphafold before running this function"
        ) from e
    return params
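

# Illustrative sketch (not part of the original module): pulling the parameters
# of a single AlphaFold sub-module, after import_alphafold() has been called.
# The key prefix below is hypothetical.
#
#     params = fetch_alphafold_module_weights(
#         "alphafold/alphafold_iteration/structure_module"  # hypothetical path
#     )
#
# Keys are matched by substring, and the path components before the final one
# are stripped so the result lines up with the sub-module's own parameter names.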


def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
    # Helper function for comparing absolute differences of two torch tensors.
    abs_diff = torch.abs(expected - actual)
    err = compare_func(abs_diff)
    zero_tensor = torch.tensor(0, dtype=err.dtype)
    rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6
    torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)


def assert_max_abs_diff_small(expected, actual, eps):
    _assert_abs_diff_small_base(torch.max, expected, actual, eps)


def assert_mean_abs_diff_small(expected, actual, eps):
    _assert_abs_diff_small_base(torch.mean, expected, actual, eps)
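

# Illustrative sketch (not part of the original module): a typical comparison
# test fetches the cached pretrained model and checks one of its outputs against
# a reference tensor within some tolerance. "batch", "reference_out", and the
# output key are hypothetical placeholders.
#
#     model = get_global_pretrained_openfold()
#     with torch.no_grad():
#         out = model(batch)
#     assert_max_abs_diff_small(reference_out, out["final_atom_positions"], 1e-4)
#
# assert_max_abs_diff_small bounds the worst-case element-wise error, while
# assert_mean_abs_diff_small bounds the average error.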