# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

import openfold.data.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4, random_asym_ids

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


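# Tests for OpenFold's feature/frame utilities (openfold.utils.feats and
# openfold.data.data_transforms). The *_compare tests run only when AlphaFold
# is installed and check the PyTorch ports against the original JAX/Haiku
# implementations; the *_shape tests only verify output shapes.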
class TestFeats(unittest.TestCase):
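    # Bind the matching AlphaFold reference modules once per class, choosing the
    # multimer or monomer variants according to the test configuration.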
    @classmethod
    def setUpClass(cls):
        if compare_utils.alphafold_is_installed():
            if consts.is_multimer:
                cls.am_atom = alphafold.model.all_atom_multimer
                cls.am_fold = alphafold.model.folding_multimer
                cls.am_modules = alphafold.model.modules_multimer
                cls.am_rigid = alphafold.model.geometry
            else:
                cls.am_atom = alphafold.model.all_atom
                cls.am_fold = alphafold.model.folding
                cls.am_modules = alphafold.model.modules
                cls.am_rigid = alphafold.model.r3

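    # Compare OpenFold's pseudo_beta_fn against the AlphaFold reference on random
    # atom37 positions and masks; outputs are expected to agree within consts.eps.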
    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype,
                all_atom_pos,
                all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
        )

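    # Compare the atom37_to_torsion_angles transform against the AlphaFold
    # reference on a random stack of templates.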
    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
            np.float32
        )

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)

        out_repro = data_transforms.atom37_to_torsion_angles()(
            {
                "aatype": torch.as_tensor(aatype).cuda(),
                "all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
                "all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
            },
        )
        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
            < 0.01
        )
        self.assertTrue(
            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
            < 0.01
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
            < consts.eps
        )

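    # Compare atom37_to_frames against the AlphaFold reference. The ground-truth
    # frames are first converted to 4x4 homogeneous transforms (from the multimer
    # 3x4 arrays or the monomer flat rotation/translation representation) so that
    # both implementations can be compared entry-wise.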
    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            if consts.is_multimer:
                all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
            return self.am_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, **batch)

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)
            to_tensor = (lambda t: torch.tensor(np.array(t))
                         if not isinstance(t, self.am_rigid.Rigid3Array)
                         else torch.tensor(np.array(t.to_array())))
        else:
            to_tensor = lambda t: torch.tensor(np.array(t))

        out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

        def rigid3x4_to_4x4(rigid3arr):
            four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
            four_by_four[..., :3, :4] = rigid3arr
            four_by_four[..., 3, 3] = 1
            return four_by_four

        def flat12_to_4x4(flat12):
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]

            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1

            return four_by_four

        convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4

        out_gt["rigidgroups_gt_frames"] = convert_func(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = convert_func(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, v in out_gt.items():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

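    # Shape check: torsion_angles_to_frames maps backbone rigids and 7 torsion
    # angles per residue to 8 rigid-group frames per residue.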
    def test_torsion_angles_to_frames_shape(self):
        batch_size = 2
        n = 5
        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        angles = torch.rand((batch_size, n, 7, 2))

        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))

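    # Compare torsion_angles_to_frames against the AlphaFold reference: both
    # implementations receive the same random backbone frames and torsion angles,
    # and the resulting frames are compared as 4x4 homogeneous transforms.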
    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return self.am_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res,))

        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)

        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors
        out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
        out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()

        if consts.is_multimer:
            rots_gt = torch.as_tensor(np.array(out_gt_rot))
            trans_gt = torch.as_tensor(np.array(out_gt_trans))
        else:
            rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
            trans_gt = list(
                map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
            )
            rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
            rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
            trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)

        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_tensor_4x4().cpu()

        self.assertTrue(
            torch.max(torch.abs(transforms_gt - transforms_repro)) < consts.eps
        )

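    # Shape check: frames_and_literature_positions_to_atom14_pos should produce
    # atom14 coordinates of shape [batch, n_res, 14, 3].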
    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))

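    # Compare frames_and_literature_positions_to_atom14_pos against the AlphaFold
    # reference, feeding both the same random aatype and rigid-group frames.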
    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        def run_f(aatype, affines):
            return self.am_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res

        aatype = np.random.randint(0, 21, size=(n_res,))

        affines = random_affines_4x4((n_res, 8))

        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        if consts.is_multimer:
            out_gt = torch.as_tensor(np.array(out_gt.to_array()))
        else:
            out_gt = torch.stack(
                [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
            )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)


if __name__ == "__main__":
    unittest.main()