# openfold/tests/test_utils.py
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import torch
import unittest

from openfold.utils.rigid_utils import (
    Rotation,
    Rigid,
    quat_to_rot,
    rot_to_quat,
)
from openfold.utils.chunk_utils import chunk_layer, _chunk_slice

import tests.compare_utils as compare_utils
from tests.config import consts

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
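
# Rotations of +90 and -90 degrees about the x-axis, used as simple fixtures
# for the rigid-transform tests below.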
X_90_ROT = torch.tensor(
    [
        [1, 0, 0],
        [0, 0, -1],
        [0, 1, 0],
    ]
)

X_NEG_90_ROT = torch.tensor(
    [
        [1, 0, 0],
        [0, 0, 1],
        [0, -1, 0],
    ]
)


class TestUtils(unittest.TestCase):
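    # Rigid.from_3_points builds an orthonormal frame from three points, with
    # the second point taken as the origin of the frame (hence the translation
    # check below).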
    def test_rigid_from_3_points_shape(self):
        batch_size = 2
        n_res = 5

        x1 = torch.rand((batch_size, n_res, 3))
        x2 = torch.rand((batch_size, n_res, 3))
        x3 = torch.rand((batch_size, n_res, 3))
        r = Rigid.from_3_points(x1, x2, x3)

        rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
        self.assertTrue(rot.shape == (batch_size, n_res, 3, 3))
        self.assertTrue(torch.all(tra == x2))
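
    # Rigid.from_tensor_4x4 reads a homogeneous transform: the rotation sits
    # in the upper-left 3x3 block and the translation in the last column.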
    def test_rigid_from_4x4(self):
        batch_size = 2
        transf = [
            [1, 0, 0, 1],
            [0, 0, -1, 2],
            [0, 1, 0, 3],
            [0, 0, 0, 1],
        ]
        transf = torch.tensor(transf)

        true_rot = transf[:3, :3]
        true_trans = transf[:3, 3]

        transf = torch.stack([transf for _ in range(batch_size)], dim=0)
        r = Rigid.from_tensor_4x4(transf)

        rot, tra = r.get_rots().get_rot_mats(), r.get_trans()
        self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
        self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))

    def test_rigid_shape(self):
        batch_size = 2
        n = 5
        transf = Rigid(
            Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
            torch.rand((batch_size, n, 3))
        )

        self.assertTrue(transf.shape == (batch_size, n))

    def test_rigid_cat(self):
        batch_size = 2
        n = 5
        transf = Rigid(
            Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
            torch.rand((batch_size, n, 3))
        )

        transf_cat = Rigid.cat([transf, transf], dim=0)

        transf_rots = transf.get_rots().get_rot_mats()
        transf_cat_rots = transf_cat.get_rots().get_rot_mats()

        self.assertTrue(transf_cat_rots.shape == (batch_size * 2, n, 3, 3))

        transf_cat = Rigid.cat([transf, transf], dim=1)
        transf_cat_rots = transf_cat.get_rots().get_rot_mats()

        self.assertTrue(transf_cat_rots.shape == (batch_size, n * 2, 3, 3))
        self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
        self.assertTrue(
            torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
        )

    def test_rigid_compose(self):
        trans_1 = [0, 1, 0]
        trans_2 = [0, 0, 1]

        t1 = Rigid(
            Rotation(rot_mats=X_90_ROT),
            torch.tensor(trans_1)
        )
        t2 = Rigid(
            Rotation(rot_mats=X_NEG_90_ROT),
            torch.tensor(trans_2)
        )

        t3 = t1.compose(t2)

        # The +90 and -90 degree x-rotations compose to the identity, and the
        # rotated second translation exactly cancels the first.
        self.assertTrue(
            torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
        )
        self.assertTrue(
            torch.all(t3.get_trans() == 0)
        )

    def test_rigid_apply(self):
        rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
        trans = torch.tensor([1, 1, 1])
        trans = torch.stack([trans, trans], dim=0)

        t = Rigid(Rotation(rot_mats=rots), trans)

        x = torch.arange(30)
        x = torch.stack([x, x], dim=0)
        x = x.view(2, -1, 3)  # [2, 10, 3]

        pts = t[..., None].apply(x)

        # All simple consequences of the two x-axis rotations
        self.assertTrue(torch.all(pts[..., 0] == x[..., 0] + 1))
        self.assertTrue(torch.all(pts[0, :, 1] == x[0, :, 2] * -1 + 1))
        self.assertTrue(torch.all(pts[1, :, 1] == x[1, :, 2] + 1))
        self.assertTrue(torch.all(pts[0, :, 2] == x[0, :, 1] + 1))
        self.assertTrue(torch.all(pts[1, :, 2] == x[1, :, 1] * -1 + 1))
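
    # The quaternion utilities use the real-first (w, x, y, z) convention: a
    # rotation by theta about the x-axis is (cos(theta / 2), sin(theta / 2), 0, 0).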
    def test_quat_to_rot(self):
        forty_five = math.pi / 4
        quat = torch.tensor([math.cos(forty_five), math.sin(forty_five), 0, 0])
        rot = quat_to_rot(quat)
        eps = 1e-07
        self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))

    def test_rot_to_quat(self):
        quat = rot_to_quat(X_90_ROT)
        eps = 1e-07
        ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
        self.assertTrue(torch.all(torch.abs(quat - ans) < eps))
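
    # chunk_layer runs a module over the flattened batch dimensions in chunks
    # of chunk_size and reassembles the outputs; the result should match a
    # single unchunked call exactly.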
    def test_chunk_layer_tensor(self):
        x = torch.rand(2, 4, 5, 15)
        l = torch.nn.Linear(15, 30)

        chunked = chunk_layer(l, {"input": x}, chunk_size=4, no_batch_dims=3)
        unchunked = l(x)

        self.assertTrue(torch.all(chunked == unchunked))

    def test_chunk_layer_dict(self):
        class LinearDictLayer(torch.nn.Linear):
            def forward(self, input):
                out = super().forward(input)
                return {"out": out, "inner": {"out": out + 1}}

        x = torch.rand(2, 4, 5, 15)
        l = LinearDictLayer(15, 30)

        chunked = chunk_layer(l, {"input": x}, chunk_size=4, no_batch_dims=3)
        unchunked = l(x)

        self.assertTrue(torch.all(chunked["out"] == unchunked["out"]))
        self.assertTrue(
            torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
        )

    def test_chunk_slice_dict(self):
        x = torch.rand(3, 4, 3, 5)
        x_flat = x.view(-1, 5)

        prod = 1
        for d in x.shape[:-1]:
            prod = prod * d

        # Every contiguous slice over the flattened batch dimensions should
        # match the corresponding slice of the explicitly flattened tensor.
        for i in range(prod):
            for j in range(i + 1, prod + 1):
                chunked = _chunk_slice(x, i, j, len(x.shape[:-1]))
                chunked_flattened = x_flat[i:j]
                self.assertTrue(torch.all(chunked == chunked_flattened))
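
    # Parity check against AlphaFold: composing a quaternion update vector with
    # Rigid.compose_q_update_vec should reproduce QuatAffine.pre_compose.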
    @compare_utils.skip_unless_alphafold_installed()
    def test_pre_compose_compare(self):
        quat = np.random.rand(20, 4)
        trans = [np.random.rand(20) for _ in range(3)]
        quat_affine = alphafold.model.quat_affine.QuatAffine(
            quat, translation=trans
        )

        update_vec = np.random.rand(20, 6)
        new_gt = quat_affine.pre_compose(update_vec)

        quat_t = torch.tensor(quat)
        trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
        rigid = Rigid(Rotation(quats=quat_t), trans_t)

        new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))

        new_gt_q = torch.tensor(np.array(new_gt.quaternion))
        new_gt_t = torch.stack(
            [torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
        )
        new_repro_q = new_repro.get_rots().get_quats()
        new_repro_t = new_repro.get_trans()

        self.assertTrue(
            torch.max(torch.abs(new_gt_q - new_repro_q)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(new_gt_t - new_repro_t)) < consts.eps
        )
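

# Convenience entry point: run the suite with the standard unittest runner.
# Assumes invocation from the repository root so that the `tests` package
# imports above resolve.
if __name__ == "__main__":
    unittest.main()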