227 lines
10 KiB
Python
227 lines
10 KiB
Python
import gzip
|
|
import pickle
|
|
|
|
import numpy as np
|
|
import torch
|
|
import unittest
|
|
|
|
from openfold.data.data_transforms import make_seq_mask, add_distillation_flag, make_all_atom_aatype, fix_templates_aatype, \
|
|
correct_msa_restypes, squeeze_features, randomly_replace_msa_with_unknown, MSA_FEATURE_NAMES, sample_msa, \
|
|
crop_extra_msa, delete_extra_msa, nearest_neighbor_clusters, make_msa_mask, make_hhblits_profile, make_masked_msa, \
|
|
make_msa_feat, crop_templates, make_atom14_masks
|
|
from tests.config import config
|
|
|
|
|
|
class TestDataTransforms(unittest.TestCase):
    """Smoke tests for the transforms imported from ``openfold.data.data_transforms``.

    Most tests read pickled feature fixtures through paths that are relative
    to the current working directory, so the suite must be started from the
    directory that contains ``tests/``.

    Several transforms are invoked through ``.__wrapped__``, i.e. the function
    underneath their decorator (presumably a currying decorator -- confirm in
    ``data_transforms``), so that all arguments can be passed in one call.
    """

    # Fixture locations, relative to the current working directory.
    _FEATURES_PKL = 'tests/test_data/features.pkl'
    _SAMPLE_FEATS_GZ = 'tests/test_data/sample_feats.pickle.gz'

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @classmethod
    def _load_features(cls):
        """Return the object unpickled from the plain ``features.pkl`` fixture."""
        # NOTE(review): pickle can execute arbitrary code while loading; this
        # is acceptable only as long as the fixture is a trusted repo file.
        with open(cls._FEATURES_PKL, 'rb') as file:
            return pickle.load(file)

    @classmethod
    def _load_sample_feats(cls):
        """Return the object unpickled from the gzip-compressed fixture."""
        # NOTE(review): same pickle caveat as in ``_load_features``.
        with gzip.open(cls._SAMPLE_FEATS_GZ, 'rb') as file:
            return pickle.load(file)

    @staticmethod
    def _one_hot_residues(indices, num_classes=20):
        """Build a float32 one-hot matrix of shape ``(len(indices), num_classes)``.

        Args:
            indices: iterable of integer class ids, each in ``[0, num_classes)``.
            num_classes: width of the one-hot encoding.
        """
        column = torch.tensor(list(indices), dtype=torch.int64).unsqueeze(-1)
        one_hot = torch.zeros(column.shape[0], num_classes, dtype=torch.float32)
        one_hot.scatter_(1, column, 1)
        return one_hot

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------
    def test_make_seq_mask(self):
        """``make_seq_mask`` adds a ``seq_mask`` entry of shape (20, 20) here."""
        num_residues = 20
        protein = {'aatype': self._one_hot_residues(range(num_residues))}

        protein = make_seq_mask(protein)

        self.assertIn('seq_mask', protein)
        self.assertEqual(
            protein['seq_mask'].shape, torch.Size((num_residues, 20)))

    def test_add_distillation_flag(self):
        """``add_distillation_flag`` stores the given flag under ``is_distillation``."""
        protein = add_distillation_flag.__wrapped__({}, True)

        self.assertIn('is_distillation', protein)
        self.assertIs(protein['is_distillation'], True)

    def test_make_all_atom_aatype(self):
        """``all_atom_aatype`` is added and has the same shape as ``aatype``."""
        protein = {'aatype': self._one_hot_residues(range(20))}

        protein = make_all_atom_aatype(protein)

        self.assertIn('all_atom_aatype', protein)
        self.assertEqual(
            protein['all_atom_aatype'].shape, protein['aatype'].shape)

    def test_fix_templates_aatype(self):
        """One-hot template residues are converted to the expected index order."""
        # Two passes over the 20 residue ids, one-hot encoded, with a leading
        # axis of size 1 -> shape (1, 40, 20).
        template_aatype = self._one_hot_residues(
            list(range(20)) * 2).unsqueeze(0)
        # The same tensor object is deliberately used for both keys.
        protein = {'template_aatype': template_aatype, 'aatype': template_aatype}

        protein = fix_templates_aatype(protein)

        template_seq_ours = torch.tensor(
            [[0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18] * 2])
        self.assertTrue(
            torch.all(torch.eq(protein['template_aatype'], template_seq_ours)))

    def test_correct_msa_restypes(self):
        """``correct_msa_restypes`` leaves the shape of ``msa`` unchanged."""
        features = self._load_features()
        protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}

        protein = correct_msa_restypes(protein)

        # Plain tuple comparison: also catches a change in rank, which an
        # element-wise tensor comparison of the shapes would not report cleanly.
        self.assertEqual(
            tuple(protein['msa'].shape), tuple(features['msa'].shape))

    def test_squeeze_features(self):
        """``squeeze_features`` undoes the extra axes added by this test."""
        features = self._load_features()

        features_list = [
            'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
            'superfamily', 'deletion_matrix', 'resolution',
            'between_segment_residues', 'residue_index', 'template_all_atom_mask']

        # Give every available feature a trailing singleton axis.
        protein = {'aatype': torch.tensor(features['aatype'])}
        for k in features_list:
            if k not in features:
                continue
            if k in ('domain_name', 'sequence'):
                # Not converted with torch.tensor; kept as numpy arrays.
                protein[k] = np.expand_dims(features[k], -1)
            else:
                protein[k] = torch.tensor(features[k]).unsqueeze(-1)

        # These two additionally get a leading singleton axis.
        for k in ('seq_length', 'num_alignments'):
            if k in protein:
                protein[k] = protein[k].clone().detach().unsqueeze(0)

        protein_squeezed = squeeze_features(protein)

        for k in features_list:
            if k in protein:
                self.assertEqual(
                    tuple(protein_squeezed[k].shape),
                    tuple(features[k].shape),
                    msg=k)

    def test_randomly_replace_msa_with_unknown(self):
        """Replacement keeps shapes and only ever writes the unknown index.

        The transform is random, so no exact proportion is asserted; instead
        every position must either be unchanged or hold ``x_idx``.
        """
        features = self._load_features()

        original = {
            'msa': torch.tensor(features['msa']),
            'aatype': torch.argmax(torch.tensor(features['aatype']), dim=1),
        }
        # Hand clones to the transform so the originals stay available for
        # comparison even if the transform works in place.
        protein = {k: v.clone() for k, v in original.items()}
        replace_proportion = 0.15
        # Index expected for the unknown residue -- assumed to match the
        # constant used inside the transform; TODO confirm.
        x_idx = 20

        protein = randomly_replace_msa_with_unknown.__wrapped__(
            protein, replace_proportion)

        for k, before in original.items():
            after = protein[k]
            self.assertEqual(tuple(after.shape), tuple(before.shape), msg=k)
            unchanged_or_unknown = (after == before) | (after == x_idx)
            self.assertTrue(torch.all(unchanged_or_unknown), msg=k)

    def test_sample_msa(self):
        """``sample_msa`` splits each MSA feature into kept and ``extra_`` rows."""
        features = self._load_features()

        max_seq = 1000
        keep_extra = True
        protein = {
            k: torch.tensor(features[k])
            for k in MSA_FEATURE_NAMES if k in features
        }

        # A shallow copy is passed so ``protein`` keeps the unsampled tensors.
        protein_processed = sample_msa.__wrapped__(
            protein.copy(), max_seq, keep_extra)

        for k in MSA_FEATURE_NAMES:
            if k in protein and keep_extra:
                num_rows = protein[k].shape[0]
                num_kept = min(num_rows, max_seq)
                self.assertEqual(
                    protein_processed[k].shape[0], num_kept, msg=k)
                self.assertIn('extra_' + k, protein_processed)
                self.assertEqual(
                    protein_processed['extra_' + k].shape[0],
                    num_rows - num_kept,
                    msg=k)

    def test_crop_extra_msa(self):
        """``crop_extra_msa`` keeps at most ``max_extra_msa`` extra rows."""
        features = self._load_features()

        max_extra_msa = 10
        protein = {'extra_msa': torch.tensor(features['msa'])}
        num_seq = protein['extra_msa'].shape[0]

        protein = crop_extra_msa.__wrapped__(protein, max_extra_msa)

        for k in MSA_FEATURE_NAMES:
            extra_key = 'extra_' + k
            if extra_key in protein:
                self.assertEqual(
                    protein[extra_key].shape[0],
                    min(max_extra_msa, num_seq),
                    msg=extra_key)

    def test_delete_extra_msa(self):
        """``delete_extra_msa`` removes every ``extra_*`` MSA feature."""
        extra_msa = torch.rand((512, 100, 23))
        # Deletion matrix: same leading dims as the MSA, last axis of size 1.
        deletion_shape = list(extra_msa.shape)
        deletion_shape[2] = 1
        protein = {
            'extra_msa': extra_msa,
            'extra_deletion_matrix': torch.rand(deletion_shape),
        }

        protein = delete_extra_msa(protein)

        for k in MSA_FEATURE_NAMES:
            self.assertNotIn('extra_' + k, protein)
        self.assertNotIn('extra_msa', protein)

    def test_nearest_neighbor_clusters(self):
        """``nearest_neighbor_clusters`` adds ``extra_cluster_assignment``."""
        features = self._load_sample_feats()

        # Index 0 selects the first entry along the leading axis of each
        # stored feature.
        protein = {
            'msa': torch.tensor(features['true_msa'][0], dtype=torch.int64),
            'msa_mask': torch.tensor(features['msa_mask'][0], dtype=torch.int64),
            'extra_msa': torch.tensor(features['extra_msa'][0], dtype=torch.int64),
            'extra_msa_mask': torch.tensor(
                features['extra_msa_mask'][0], dtype=torch.int64),
        }

        protein = nearest_neighbor_clusters.__wrapped__(protein, 0)

        self.assertIn('extra_cluster_assignment', protein)

    def test_make_msa_mask(self):
        """``make_msa_mask`` adds one ``msa_row_mask`` entry per MSA row."""
        features = self._load_features()
        msa_mat = torch.tensor(features['msa'])

        protein = make_msa_mask({'msa': msa_mat})

        self.assertIn('msa_row_mask', protein)
        self.assertEqual(protein['msa_row_mask'].shape[0], msa_mat.shape[0])

    def test_make_hhblits_profile(self):
        """``hhblits_profile`` has shape ``(msa.shape[1], 22)``."""
        features = self._load_features()
        protein = {'msa': torch.tensor(features['msa'], dtype=torch.int64)}

        protein = make_hhblits_profile(protein)

        self.assertIn('hhblits_profile', protein)
        self.assertEqual(
            protein['hhblits_profile'].shape,
            torch.Size((protein['msa'].shape[1], 22)))

    def test_make_masked_msa(self):
        """Masking adds its outputs and leaves unmasked positions untouched."""
        features = self._load_features()

        protein = {
            'msa': torch.tensor(features['msa'], dtype=torch.int64),
            'aatype': torch.tensor(features['aatype'], dtype=torch.int64),
        }
        # The profile is computed first, as in the original test; presumably
        # make_masked_msa reads it -- confirm in data_transforms.
        protein = make_hhblits_profile(protein)
        masked_msa_config = config.data.common.masked_msa

        protein = make_masked_msa.__wrapped__(
            protein, masked_msa_config, replace_fraction=0.15, seed=42)

        self.assertIn('bert_mask', protein)
        self.assertIn('true_msa', protein)
        self.assertIn('msa', protein)
        self.assertGreaterEqual(protein['bert_mask'].sum().item(), 0)
        # Where bert_mask is 0 the masked MSA must equal the true MSA.
        keep = 1 - protein['bert_mask']
        self.assertTrue(torch.all(torch.eq(
            protein['true_msa'] * keep, protein['msa'] * keep)))

    def test_make_msa_feat(self):
        """``make_msa_feat`` adds ``msa_feat`` and ``target_feat`` with fixed widths."""
        features = self._load_features()

        protein = {
            'between_segment_residues': torch.tensor(
                features['between_segment_residues']),
            'msa': torch.tensor(features['msa'], dtype=torch.int64),
            'deletion_matrix': torch.tensor(features['deletion_matrix_int']),
            'aatype': torch.argmax(torch.tensor(features['aatype']), dim=1),
        }

        protein = make_msa_feat.__wrapped__(protein)

        self.assertIn('msa_feat', protein)
        self.assertIn('target_feat', protein)
        self.assertEqual(
            protein['target_feat'].shape,
            torch.Size((protein['msa'].shape[1], 22)))
        self.assertEqual(
            protein['msa_feat'].shape,
            torch.Size((*protein['msa'].shape, 25)))

    def test_crop_templates(self):
        """``crop_templates`` truncates template features to ``max_templates``."""
        features = self._load_sample_feats()

        # MSA tensors are used as stand-ins for the template features; only
        # the size of the leading axis is checked.
        protein = {
            'template_aatype': torch.tensor(features['true_msa'][0]),
            'template_all_atom_masks': torch.tensor(features['msa_mask'][0]),
        }
        max_templates = 2

        protein = crop_templates.__wrapped__(protein, max_templates)

        self.assertEqual(protein['template_aatype'].shape[0], max_templates)
        self.assertEqual(
            protein['template_all_atom_masks'].shape[0], max_templates)

    def test_make_atom14_masks(self):
        """``make_atom14_masks`` adds the four atom14/atom37 mask and index entries."""
        features = self._load_sample_feats()
        protein = {'aatype': torch.tensor(features['aatype'][0])}

        protein = make_atom14_masks(protein)

        for key in ('atom14_atom_exists', 'residx_atom14_to_atom37',
                    'residx_atom37_to_atom14', 'atom37_atom_exists'):
            self.assertIn(key, protein)
|
|
|
|
|
|
# Hand control to unittest's command-line runner when this module is executed
# as a script; nothing happens when it is merely imported.
if __name__ == "__main__":
    unittest.main()
|