253 lines
7.6 KiB
Python
253 lines
7.6 KiB
Python
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import math
|
|
import numpy as np
|
|
import torch
|
|
import unittest
|
|
|
|
from openfold.utils.rigid_utils import (
|
|
Rotation,
|
|
Rigid,
|
|
quat_to_rot,
|
|
rot_to_quat,
|
|
)
|
|
from openfold.utils.chunk_utils import chunk_layer, _chunk_slice
|
|
import tests.compare_utils as compare_utils
|
|
from tests.config import consts
|
|
|
|
# The AlphaFold reference stack is optional: it is only imported when
# compare_utils reports that it is installed, so the rest of this module
# still loads without it.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk
|
|
|
|
# Rotation by +90 degrees about the x-axis: maps (x, y, z) to (x, -z, y).
X_90_ROT = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]])

# Rotation by -90 degrees about the x-axis: maps (x, y, z) to (x, z, -y).
X_NEG_90_ROT = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
|
|
|
|
class TestUtils(unittest.TestCase):
    """Tests for the Rigid/Rotation helpers and the chunking utilities."""

    def test_rigid_from_3_points_shape(self):
        """from_3_points keeps the batch dims and uses x2 as translation."""
        batch_size = 2
        n_res = 5

        x1 = torch.rand((batch_size, n_res, 3))
        x2 = torch.rand((batch_size, n_res, 3))
        x3 = torch.rand((batch_size, n_res, 3))

        r = Rigid.from_3_points(x1, x2, x3)

        rot, tra = r.get_rots().get_rot_mats(), r.get_trans()

        # assertEqual reports both shapes on failure, unlike assertTrue.
        self.assertEqual(rot.shape, (batch_size, n_res, 3, 3))
        self.assertTrue(torch.all(tra == x2))

    def test_rigid_from_4x4(self):
        """from_tensor_4x4 splits a batched 4x4 into rotation and translation."""
        batch_size = 2
        transf = [
            [1, 0, 0, 1],
            [0, 0, -1, 2],
            [0, 1, 0, 3],
            [0, 0, 0, 1],
        ]
        transf = torch.tensor(transf)

        # Upper-left 3x3 is the rotation, last column the translation.
        true_rot = transf[:3, :3]
        true_trans = transf[:3, 3]

        transf = torch.stack([transf for _ in range(batch_size)], dim=0)

        r = Rigid.from_tensor_4x4(transf)

        rot, tra = r.get_rots().get_rot_mats(), r.get_trans()

        # unsqueeze(0) lets the single reference broadcast over the batch.
        self.assertTrue(torch.all(rot == true_rot.unsqueeze(0)))
        self.assertTrue(torch.all(tra == true_trans.unsqueeze(0)))

    def test_rigid_shape(self):
        """Rigid.shape equals the batch dims of its components."""
        batch_size = 2
        n = 5
        transf = Rigid(
            Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
            torch.rand((batch_size, n, 3)),
        )

        self.assertEqual(transf.shape, (batch_size, n))

    def test_rigid_cat(self):
        """Rigid.cat concatenates rotations and translations along `dim`."""
        batch_size = 2
        n = 5
        transf = Rigid(
            Rotation(rot_mats=torch.rand((batch_size, n, 3, 3))),
            torch.rand((batch_size, n, 3)),
        )

        # Concatenation along the first batch dimension: shape check only.
        transf_cat = Rigid.cat([transf, transf], dim=0)

        transf_rots = transf.get_rots().get_rot_mats()
        transf_cat_rots = transf_cat.get_rots().get_rot_mats()

        self.assertEqual(transf_cat_rots.shape, (batch_size * 2, n, 3, 3))

        # Concatenation along the second batch dimension: shape and contents.
        transf_cat = Rigid.cat([transf, transf], dim=1)
        transf_cat_rots = transf_cat.get_rots().get_rot_mats()

        self.assertEqual(transf_cat_rots.shape, (batch_size, n * 2, 3, 3))

        # The first n entries must be the first operand, unchanged.
        self.assertTrue(torch.all(transf_cat_rots[:, :n] == transf_rots))
        self.assertTrue(
            torch.all(transf_cat.get_trans()[:, :n] == transf.get_trans())
        )

    def test_rigid_compose(self):
        """Composing the +90 and -90 degree x-axis transforms gives identity.

        The translations are chosen so that the composed translation is
        expected to be zero as well.
        """
        trans_1 = [0, 1, 0]
        trans_2 = [0, 0, 1]

        t1 = Rigid(
            Rotation(rot_mats=X_90_ROT),
            torch.tensor(trans_1),
        )
        t2 = Rigid(
            Rotation(rot_mats=X_NEG_90_ROT),
            torch.tensor(trans_2),
        )

        t3 = t1.compose(t2)

        self.assertTrue(
            torch.all(t3.get_rots().get_rot_mats() == torch.eye(3))
        )
        self.assertTrue(
            torch.all(t3.get_trans() == 0)
        )

    def test_rigid_apply(self):
        """Rigid.apply rotates, then translates, batches of points."""
        rots = torch.stack([X_90_ROT, X_NEG_90_ROT], dim=0)
        trans = torch.tensor([1, 1, 1])
        trans = torch.stack([trans, trans], dim=0)

        t = Rigid(Rotation(rot_mats=rots), trans)

        x = torch.arange(30)
        x = torch.stack([x, x], dim=0)
        x = x.view(2, -1, 3)  # [2, 10, 3]

        # Add a trailing dim so each transform applies to all 10 points.
        pts = t[..., None].apply(x)

        # All simple consequences of the two x-axis rotations
        self.assertTrue(torch.all(pts[..., 0] == x[..., 0] + 1))
        self.assertTrue(torch.all(pts[0, :, 1] == x[0, :, 2] * -1 + 1))
        self.assertTrue(torch.all(pts[1, :, 1] == x[1, :, 2] + 1))
        self.assertTrue(torch.all(pts[0, :, 2] == x[0, :, 1] + 1))
        self.assertTrue(torch.all(pts[1, :, 2] == x[1, :, 1] * -1 + 1))

    def test_quat_to_rot(self):
        """quat_to_rot maps (cos 45deg, sin 45deg, 0, 0) to X_90_ROT."""
        forty_five = math.pi / 4
        quat = torch.tensor([math.cos(forty_five), math.sin(forty_five), 0, 0])
        rot = quat_to_rot(quat)
        eps = 1e-07
        self.assertTrue(torch.all(torch.abs(rot - X_90_ROT) < eps))

    def test_rot_to_quat(self):
        """rot_to_quat maps X_90_ROT to (sqrt(0.5), sqrt(0.5), 0, 0)."""
        quat = rot_to_quat(X_90_ROT)
        eps = 1e-07
        ans = torch.tensor([math.sqrt(0.5), math.sqrt(0.5), 0., 0.])
        self.assertTrue(torch.all(torch.abs(quat - ans) < eps))

    def test_chunk_layer_tensor(self):
        """chunk_layer on a tensor-returning layer matches the direct call."""
        x = torch.rand(2, 4, 5, 15)
        layer = torch.nn.Linear(15, 30)
        chunked = chunk_layer(
            layer, {"input": x}, chunk_size=4, no_batch_dims=3
        )
        unchunked = layer(x)

        self.assertTrue(torch.all(chunked == unchunked))

    def test_chunk_layer_dict(self):
        """chunk_layer handles layers that return nested dicts of tensors."""

        class LinearDictLayer(torch.nn.Linear):
            # The parameter must stay named `input`: the test passes the
            # tensor to chunk_layer under the key "input".
            def forward(self, input):
                out = super().forward(input)
                return {"out": out, "inner": {"out": out + 1}}

        x = torch.rand(2, 4, 5, 15)
        layer = LinearDictLayer(15, 30)

        chunked = chunk_layer(
            layer, {"input": x}, chunk_size=4, no_batch_dims=3
        )
        unchunked = layer(x)

        self.assertTrue(torch.all(chunked["out"] == unchunked["out"]))
        self.assertTrue(
            torch.all(chunked["inner"]["out"] == unchunked["inner"]["out"])
        )

    def test_chunk_slice_dict(self):
        """_chunk_slice matches slicing the flattened tensor for every range."""
        x = torch.rand(3, 4, 3, 5)
        x_flat = x.view(-1, 5)

        # Every dim except the last is a batch dim; flattening them gives
        # x_flat.shape[0] rows, i.e. the product of the batch dims.
        no_batch_dims = len(x.shape[:-1])
        n_flat = x_flat.shape[0]

        # Exhaustively check every non-empty [i, j) range.
        for i in range(n_flat):
            for j in range(i + 1, n_flat + 1):
                chunked = _chunk_slice(x, i, j, no_batch_dims)
                chunked_flattened = x_flat[i:j]

                self.assertTrue(torch.all(chunked == chunked_flattened))

    @compare_utils.skip_unless_alphafold_installed()
    def test_pre_compose_compare(self):
        """compose_q_update_vec agrees with AlphaFold's QuatAffine.pre_compose."""
        quat = np.random.rand(20, 4)
        trans = [np.random.rand(20) for _ in range(3)]
        quat_affine = alphafold.model.quat_affine.QuatAffine(
            quat, translation=trans
        )

        update_vec = np.random.rand(20, 6)
        new_gt = quat_affine.pre_compose(update_vec)

        # The reference takes the translation as a list of three arrays;
        # Rigid takes them stacked along a trailing dimension.
        quat_t = torch.tensor(quat)
        trans_t = torch.stack([torch.tensor(t) for t in trans], dim=-1)
        rigid = Rigid(Rotation(quats=quat_t), trans_t)
        new_repro = rigid.compose_q_update_vec(torch.tensor(update_vec))

        new_gt_q = torch.tensor(np.array(new_gt.quaternion))
        new_gt_t = torch.stack(
            [torch.tensor(np.array(t)) for t in new_gt.translation], dim=-1
        )
        new_repro_q = new_repro.get_rots().get_quats()
        new_repro_t = new_repro.get_trans()

        # assertLess reports the observed maximum deviation on failure.
        self.assertLess(
            torch.max(torch.abs(new_gt_q - new_repro_q)), consts.eps
        )
        self.assertLess(
            torch.max(torch.abs(new_gt_t - new_repro_t)), consts.eps
        )