Merge pull request #353 from dingquanyu/permutation

Update multi-chain permutation and training code
Christina Floristean 2023-09-29 10:54:34 -04:00 committed by GitHub
commit 0ca661460d
6 changed files with 257 additions and 177 deletions

View File

@@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import (
tensor_tree_map,
)
import random
logging.basicConfig(level=logging.INFO)
@contextlib.contextmanager
def temp_fasta_file(sequence_str):
@@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
)
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
def _parse_mmcif(self, path, file_id,alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
@@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
alignment_index=alignment_index
)
)
return data
@@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
def idx_to_mmcif_id(self, idx):
return self._mmcifs[idx]
def __getitem__(self, idx):
mmcif_id = self.idx_to_mmcif_id(idx)
alignment_index = None
if(self.mode == 'train' or self.mode == 'eval'):
path = os.path.join(self.data_dir, f"{mmcif_id}")
ext = None
@@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
if (self._output_raw):
return data
# process all_chain_features
data = self.feature_pipeline.process_features(data,
data,ground_truth = self.feature_pipeline.process_features(data,
mode=self.mode,
is_multimer=True)
# if it's inference mode, only need all_chain_features
data["batch_idx"] = torch.tensor(
[idx for _ in range(data["aatype"].shape[-1])],
dtype=torch.int64,
device=data["aatype"].device)
return data
return data, ground_truth
def __len__(self):
return len(self._chain_ids)
@@ -723,9 +724,9 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
mmcif_id = dataset.idx_to_mmcif_id(i)
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
if deterministic_multimer_train_filter(mmcif_data_cache_entry,
max_resolution=9,
minimum_number_of_residues=5):
max_resolution=9):
selected_idx.append(i)
logging.info(f"Originally {len(mmcif_data_cache)} mmcifs. After filtering: {len(selected_idx)}")
else:
selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
return selected_idx
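For orientation, the net effect of the dataset changes above is that __getitem__ now returns a (data, ground_truth) pair in train/eval mode instead of a single feature dict. Below is a minimal, self-contained sketch of that contract; ToyPairDataset and all shapes are illustrative stand-ins, not code from this commit. PyTorch's default collation keeps the pairing intact.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyPairDataset(Dataset):
    """Stand-in for OpenFoldSingleMultimerDataset in train mode."""
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        # Cropped model inputs and the uncropped ground truth travel together.
        data = {"aatype": torch.zeros(64, dtype=torch.long),
                "batch_idx": torch.full((64,), idx, dtype=torch.int64)}
        ground_truth = {"all_atom_positions": torch.zeros(80, 37, 3),
                        "all_atom_mask": torch.ones(80, 37)}
        return data, ground_truth

loader = DataLoader(ToyPairDataset(), batch_size=2)
features, gt = next(iter(loader))  # default collation batches each dict separately
print(features["aatype"].shape, gt["all_atom_positions"].shape)  # (2, 64) and (2, 80, 37, 3)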

View File

@@ -81,7 +81,7 @@ def np_example_to_features(
seq_length = np_example["seq_length"]
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int"
@@ -90,15 +90,29 @@
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
with torch.no_grad():
if(not is_multimer):
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
)
if is_multimer:
if mode == 'train':
features,gt_features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
is_training=True
)
return {k: v for k, v in features.items()}, gt_features
else:
features = input_pipeline_multimer.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
is_training=False
)
return {k: v for k, v in features.items()}
else:
features = input_pipeline_multimer.process_tensors_from_config(
features = input_pipeline.process_tensors_from_config(
tensor_dict,
cfg.common,
cfg[mode],
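Note that the branching above changes the return type of np_example_to_features: multimer training returns (features, gt_features), while multimer inference and the monomer path return a single feature dict. A hedged, runnable sketch of that contract; toy_process is a hypothetical stand-in, not the real pipeline.

from typing import Dict, Tuple, Union
import torch

Feats = Dict[str, torch.Tensor]

def toy_process(tensor_dict: Feats, is_multimer: bool, mode: str) -> Union[Feats, Tuple[Feats, Feats]]:
    """Mirrors the control flow above: multimer training is the only path
    that also returns ground-truth features for the loss."""
    features = {k: v.clone() for k, v in tensor_dict.items()}
    if is_multimer and mode == "train":
        gt_features = {k: v for k, v in tensor_dict.items() if k.startswith("all_atom")}
        return features, gt_features
    return features

td = {"aatype": torch.zeros(10, dtype=torch.long), "all_atom_mask": torch.ones(10, 37)}
out = toy_process(td, is_multimer=True, mode="train")
assert isinstance(out, tuple) and len(out) == 2  # callers must unpack in train mode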

View File

@@ -21,19 +21,8 @@ from openfold.data import (
data_transforms_multimer,
)
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks,
]
if mode_cfg.supervised:
transforms.extend(
[
def grountruth_transforms_fns():
transforms = [data_transforms.make_atom14_masks,
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
@@ -41,7 +30,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)
return transforms
def nonensembled_transform_fns():
"""Input pipeline data transformers that are not ensembled."""
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks
]
return transforms
@@ -114,11 +112,29 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
return transforms
def prepare_ground_truth_features(tensors):
"""Prepare ground truth features that are only needed for loss calculation during training"""
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
return gt_tensors
def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
"""Based on the config, apply filters and transformations to the data."""
if is_training:
gt_tensors= prepare_ground_truth_features(tensors)
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
tensors['aatype'] = tensors['aatype'].to(torch.long)
nonensembled = nonensembled_transform_fns()
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
@@ -132,28 +148,14 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
d["ensemble_index"] = i
return fn(d)
no_templates = True
if("template_aatype" in tensors):
no_templates = tensors["template_aatype"].shape[0] == 0
nonensembled = nonensembled_transform_fns(
common_cfg,
mode_cfg,
)
tensors = compose(nonensembled)(tensors)
if("no_recycling_iters" in tensors):
num_recycling = int(tensors["no_recycling_iters"])
else:
num_recycling = common_cfg.max_recycling_iters
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
)
return tensors
if is_training:
return tensors,gt_tensors
else:
return tensors
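The ensembling step above maps wrap_ensemble_fn over num_recycling + 1 indices and stacks the results, which is why downstream code can slice a trailing recycling dimension (e.g. t[..., -1]). A self-contained sketch of that shape behavior, with toy_map_fn as a simplified stand-in for the real map_fn (the assumption that results are stacked along the last dimension is inferred, not quoted from this commit):

import torch

def toy_map_fn(fun, xs):
    """Apply fun to every element of xs and stack the resulting dicts along a new trailing dim."""
    ensembles = [fun(x) for x in xs]
    return {k: torch.stack([e[k] for e in ensembles], dim=-1) for k in ensembles[0]}

tensors = {"msa_feat": torch.randn(128, 49)}

def wrap_ensemble_fn(data, i):
    # The real pipeline applies the ensembled transforms here; this copy just tags the iteration.
    d = {k: v.clone() for k, v in data.items()}
    d["ensemble_index"] = torch.tensor(float(i))
    return d

num_recycling = 3
out = toy_map_fn(lambda i: wrap_ensemble_fn(tensors, int(i)), torch.arange(num_recycling + 1))
print(out["msa_feat"].shape)  # torch.Size([128, 49, 4]): one slice per recycling iteration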
@data_transforms.curry1
def compose(x, fs):

View File

@@ -1700,9 +1700,6 @@ def compute_rmsd(
atom_mask: torch.Tensor = None,
eps: float = 1e-6,
) -> torch.Tensor:
# shape check
true_atom_pos = true_atom_pos
pred_atom_pos = pred_atom_pos
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
del true_atom_pos
del pred_atom_pos
@@ -1784,20 +1781,23 @@ def get_optimal_transform(
return r, x
def get_least_asym_entity_or_longest_length(batch):
def get_least_asym_entity_or_longest_length(batch,input_asym_id):
"""
First check how many subunits each unique sequence has. If there is no tie, e.g. AABBB, then select
one of the A chains as the anchor.
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as the anchor
"""
REQUIRED_FEATURES = ['entity_id','asym_id']
seq_length = batch['seq_length'].item()
# remove padding part before selecting candidate
remove_padding = lambda t: torch.index_select(t,dim=1,index=torch.arange(seq_length,device=t.device))
batch = {k:tensor_tree_map(remove_padding,batch[k]) for k in REQUIRED_FEATURES}
Args:
batch: in this function, batch is the full ground truth features
input_asym_id: A list of asym_ids that are in the cropped input features
Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
"""
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
unique_entity_ids = torch.unique(batch["entity_id"])
entity_asym_count = {}
entity_length = {}
@@ -1822,19 +1822,15 @@ def get_least_asym_entity_or_longest_length(batch):
if len(least_asym_entities) > 1:
least_asym_entities = random.choice(least_asym_entities)
assert len(least_asym_entities)==1
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
# If there is more than one chain in the predicted output that has the same sequence
# as the chosen ground truth anchor, then randomly pick one
if len(best_pred_asym) > 1:
best_pred_asym = random.choice(best_pred_asym)
return least_asym_entities[0], best_pred_asym
least_asym_entities = least_asym_entities[0]
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
return anchor_gt_asym_id, anchor_pred_asym_ids
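A tiny worked example of the selection rule documented above, using made-up IDs (the snippet is illustrative, not from this commit): entity 1 has two chains and entity 2 has three, so the anchor is drawn from entity 1, and the predicted-anchor candidates are restricted to chains that survived cropping.

import random
import torch

# Entity 1 appears as chains (asym_id) 1 and 2; entity 2 as chains 3, 4 and 5 (AABBB).
entity_2_asym_list = {1: torch.tensor([1, 2]), 2: torch.tensor([3, 4, 5])}
input_asym_id = [1, 2, 3]  # chains present in the cropped input features

# Fewest copies wins; a tie would fall back to sequence length (not shown here).
copies_per_entity = {ent: len(asyms) for ent, asyms in entity_2_asym_list.items()}
least_asym_entity = min(copies_per_entity, key=copies_per_entity.get)

anchor_gt_asym_id = random.choice(list(entity_2_asym_list[least_asym_entity]))
anchor_pred_asym_ids = [i for i in entity_2_asym_list[least_asym_entity] if i in input_asym_id]
print(int(anchor_gt_asym_id), [int(i) for i in anchor_pred_asym_ids])  # e.g. 2 [1, 2]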
def greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
@@ -1847,6 +1843,7 @@ def greedy_align(
"""
used = [False for _ in range(len(true_ca_poses))]
align = []
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
for cur_asym_id in unique_asym_ids:
i = int(cur_asym_id - 1)
asym_mask = batch["asym_id"] == cur_asym_id
@@ -1884,9 +1881,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
def merge_labels(per_asym_residue_index, labels, align,original_nres):
def merge_labels(per_asym_residue_index,labels, align,original_nres):
"""
per_asym_residue_index: A dictionary that records which asym_id corresponds to which region of residues in the multimer complex.
Merge ground truth labels according to the permutation results
labels: list of original ground truth feats
align: list of tuples, each entry specify the corresponding label of the asym.
@@ -1898,15 +1896,12 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres):
cur_out = {}
for i, j in align:
label = labels[j][k]
cur_num_res = labels[j]['aatype'].shape[-1]
# to 1-based
cur_residue_index = per_asym_residue_index[i + 1]
if len(v.shape)<=1 or "template" in k or "row_mask" in k :
continue
else:
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
if k =='all_atom_positions':
dimension_to_merge=1
dimension_to_merge = 1
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
cur_out = [x[1] for x in sorted(cur_out.items())]
if len(cur_out)>0:
@@ -2037,19 +2032,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
def __init__(self, config):
super(AlphaFoldMultimerLoss, self).__init__(config)
self.config = config
@staticmethod
def determine_split_dim(batch)->dict:
"""A method to determine which dimension to split in split_ground_truth_labels"""
padded_dim = batch['aatype'].shape[-1]
dim_dict = {k:list(v.shape).index(padded_dim) for k,v in batch.items() if padded_dim in v.shape}
return dim_dict
@staticmethod
def split_ground_truth_labels(batch,REQUIRED_FEATURES,dim_dict):
def split_ground_truth_labels(batch,REQUIRED_FEATURES,split_dim=1):
"""
Splits ground truth features according to chains
Returns a list of feature dictionaries with only necessary ground truth features
Returns:
a list of feature dictionaries with only necessary ground truth features
required to finish multi-chain permutation
"""
unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True)
@@ -2061,11 +2051,85 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
unique_asym_ids.append(padding_asym_id)
asym_id_counts.append(padding_asym_counts)
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=split_dim)] for k, value in batch.items() if k in REQUIRED_FEATURES])))
return labels
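The one-liner above packs a per-chain split of every required feature into a list of dicts. A small, self-contained illustration of the same torch.split pattern with toy shapes (two chains of 3 and 2 residues; keys and sizes are made up):

import torch

batch = {
    "asym_id": torch.tensor([[1, 1, 1, 2, 2]]),
    "all_atom_positions": torch.randn(1, 5, 37, 3),
    "all_atom_mask": torch.ones(1, 5, 37),
}
_, counts = torch.unique(batch["asym_id"], sorted=False, return_counts=True)
counts = [int(c) for c in counts]  # torch.split wants plain ints

labels = list(map(dict, zip(*[
    [(k, v) for v in torch.split(value, counts, dim=1)]
    for k, value in batch.items() if k in ("all_atom_positions", "all_atom_mask")
])))
print(len(labels), labels[0]["all_atom_positions"].shape)  # 2 torch.Size([1, 3, 37, 3])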
@staticmethod
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=True):
def get_per_asym_residue_index(features):
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (features["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(features["residue_index"], asym_mask)
return per_asym_residue_index
@staticmethod
def get_entity_2_asym_list(batch):
"""
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
Args:
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
Returns:
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
associated with each entity.
"""
entity_2_asym_list = {}
unique_entity_ids = torch.unique(batch["entity_id"])
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
return entity_2_asym_list
@staticmethod
def calculate_input_mask(true_ca_masks,anchor_gt_idx,anchor_gt_residue,
asym_mask,pred_ca_mask):
"""
Calculate an input mask for downstream optimal transformation computation
Args:
true_ca_masks (Tensor): ca mask from ground truth.
anchor_gt_idx (Tensor): The index of the selected ground truth anchor.
anchor_gt_residue (Tensor): Residue indices of the selected ground truth anchor.
asym_mask (Tensor): Boolean tensor indicating which residues belong to the selected predicted anchor.
pred_ca_mask (Tensor): ca mask from predicted structure.
Returns:
input_mask (Tensor): A boolean mask
"""
pred_ca_mask = torch.squeeze(pred_ca_mask,0)
asym_mask = torch.squeeze(asym_mask,0)
anchor_pred_mask = pred_ca_mask[asym_mask]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue)
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
return input_mask
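The returned mask is simply the elementwise AND of the ground-truth and predicted CA masks over the anchor residues; a toy example with made-up values:

import torch

anchor_true_mask = torch.tensor([1, 1, 0, 1])      # ground-truth CA present?
anchor_pred_mask = torch.tensor([1., 0., 1., 1.])  # predicted CA present?
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
print(input_mask)  # tensor([ True, False, False,  True])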
@staticmethod
def calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_gt_residue,
true_ca_masks,pred_ca_mask,
asym_mask,
pred_ca_pos):
input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
anchor_gt_idx,anchor_gt_residue,
asym_mask,
pred_ca_mask)
input_mask = torch.squeeze(input_mask,0)
pred_ca_pos = torch.squeeze(pred_ca_pos,0)
asym_mask = torch.squeeze(asym_mask,0)
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_gt_residue)
anchor_pred_pos = pred_ca_pos[asym_mask]
r, x = get_optimal_transform(
anchor_pred_pos, torch.squeeze(anchor_true_pos,0),
mask=input_mask
)
return r, x
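The (r, x) pair returned here is consumed later in the permutation loop as ca @ r + x. A shape-only sketch of that application (the identity rotation and unit translation are placeholders, not values from this commit):

import torch

r = torch.eye(3)                   # 3x3 rotation returned by the alignment
x = torch.tensor([1.0, 0.0, 0.0])  # translation returned by the alignment
true_ca = torch.randn(10, 3)       # CA coordinates of one ground-truth chain
aligned_true_ca = true_ca.to(r.dtype) @ r + x
print(aligned_true_ca.shape)       # torch.Size([10, 3])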
@staticmethod
def multi_chain_perm_align(out, batch,permutate_chains=False):
"""
A static method that permutes chains in the ground truth
before calculating the loss.
@@ -2073,80 +2137,73 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
Details are described in Section 7.3 of the Supplementary Information of the AlphaFold-Multimer paper:
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
"""
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
per_asym_residue_index = {}
for cur_asym_id in unique_asym_ids:
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
feature, ground_truth = batch
del batch
if permutate_chains:
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
best_rmsd = float('inf')
best_align = None
# First select anchors from predicted structures and ground truths
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id'])
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
assert isinstance(labels, list)
del ground_truth
anchor_gt_idx = int(anchor_gt_asym) - 1
# Then calculate optimal transform by aligning anchors
ca_idx = rc.atom_order["CA"]
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
unique_entity_ids = torch.unique(batch["entity_id"])
entity_2_asym_list = {}
for cur_ent_id in unique_entity_ids:
ent_mask = batch["entity_id"] == cur_ent_id
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx)
anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
r, x = get_optimal_transform(
anchor_pred_pos, anchor_true_pos[0],
mask=input_mask[0]
)
del input_mask # just to save memory
del anchor_pred_mask
del anchor_true_mask
gc.collect()
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
del true_ca_poses
gc.collect()
align = greedy_align(
batch,
per_asym_residue_index,
unique_asym_ids,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
del aligned_true_ca_poses, true_ca_masks
del r, x
true_ca_poses = [
l["all_atom_positions"][..., ca_idx, :] for l in labels
] # list([nres, 3])
true_ca_masks = [
l["all_atom_mask"][..., ca_idx].long() for l in labels
] # list([nres,])
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
for candidate_pred_anchor in anchor_pred_asym_ids:
asym_mask = (feature["asym_id"] == candidate_pred_anchor).bool()
anchor_gt_residue = per_asym_residue_index[int(candidate_pred_anchor)]
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
anchor_gt_idx,anchor_gt_residue,
true_ca_masks,pred_ca_mask,
asym_mask,
pred_ca_pos
)
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
align = greedy_align(
feature,
per_asym_residue_index,
entity_2_asym_list,
pred_ca_pos,
pred_ca_mask,
aligned_true_ca_poses,
true_ca_masks,
)
merged_labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=feature['aatype'].shape[-1])
rmsd = compute_rmsd(true_atom_pos = merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
pred_atom_pos = pred_ca_pos,
atom_mask = (pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool())
if rmsd < best_rmsd:
best_rmsd = rmsd
best_align = align
del r,x
del true_ca_masks,aligned_true_ca_poses
del pred_ca_pos, pred_ca_mask
del anchor_pred_pos, anchor_true_pos
gc.collect()
print(f"finished multi-chain permutation and final align is {align}")
else:
align = list(enumerate(range(len(labels))))
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
best_align = list(enumerate(range(len(labels))))
return best_align, per_asym_residue_index
return align, per_asym_residue_index
def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
"""
Overrides the AlphaFoldLoss forward function so that
it first computes the multi-chain permutation
@@ -2156,22 +2213,25 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
batch: a pair of input features and its corresponding ground truth structure
"""
# first check if it is a monomer
features, ground_truth = batch
del batch
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
if not is_monomer:
permutate_chains = True
# first determine which dimension of the tensor to split into individual ground truth labels
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
# Then permutate ground truth chains before calculating the loss
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
permutate_chains=permutate_chains)
# Then permutate ground truth chains before calculating the loss
align,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
(features,ground_truth),
permutate_chains=permutate_chains)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
REQUIRED_FEATURES=[i for i in ground_truth.keys()])
# reorder ground truth labels according to permutation results
labels = merge_labels(per_asym_residue_index,labels,align,
original_nres=features['aatype'].shape[-1])
features.update(labels)
if (not _return_breakdown):
cum_loss = self.loss(out, features, _return_breakdown)

View File

@@ -102,9 +102,10 @@ class TestPermutation(unittest.TestCase):
batch['all_atom_mask'] = true_atom_mask
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
aligns,_ = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
dim_dict,
permutate_chains=True)
print(f"##### aligns is {aligns}")
possible_outcome = [[(0,1),(1,0),(2,3),(3,4),(4,2)],[(0,0),(1,1),(2,3),(3,4),(4,2)]]
wrong_outcome = [[(0,1),(1,0),(2,4),(3,2),(4,3)],[(0,0),(1,1),(2,2),(3,3),(4,4)]]
self.assertIn(aligns,possible_outcome)
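Based on how merge_labels indexes labels[j] for every (i, j) pair, a plausible reading of an expected outcome such as [(0,1),(1,0),(2,3),(3,4),(4,2)] is that predicted chain i is scored against ground-truth label j; this interpretation is inferred from the code above, not stated in the commit.

align = [(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)]
for pred_chain, gt_label in align:
    print(f"pred chain {pred_chain} <- ground-truth label {gt_label}")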
@@ -151,15 +152,15 @@ class TestPermutation(unittest.TestCase):
tensor_to_cuda = lambda t: t.to('cuda')
batch = tensor_tree_map(tensor_to_cuda,batch)
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
aligns,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
batch,
dim_dict,
permutate_chains=True)
print(f"##### aligns is {aligns}")
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict])
labels = merge_labels(per_asym_residue_index,labels,aligns,
labels = merge_labels(labels,aligns,
original_nres=batch['aatype'].shape[-1])
self.assertTrue(torch.equal(labels['residue_index'],batch['residue_index']))

View File

@@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
return self.model(batch)
def training_step(self, batch, batch_idx):
features,gt_features = batch
# Log it
if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
if(self.ema.device != features["aatype"].device):
self.ema.to(features["aatype"].device)
# Run the model
outputs = self(batch)
outputs = self(features)
# Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch)
features = tensor_tree_map(lambda t: t[..., -1], features)
# Compute loss
loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
outputs, (features,gt_features), _return_breakdown=True
)
# Log it
self._log(loss_breakdown, batch, outputs)
self._log(loss_breakdown, features, outputs)
return loss
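Only features is sliced here because, per the input-pipeline changes earlier in this commit, the recycling dimension is stacked onto the model inputs while gt_features is returned without it. A dict-only stand-in for tensor_tree_map illustrates the slicing; shapes are made up and the real helper also handles nested containers.

import torch

tensor_tree_map = lambda fn, tree: {k: fn(v) for k, v in tree.items()}  # simplified stand-in

features = {"aatype": torch.zeros(2, 64, 4)}            # [..., num_recycles + 1]
gt_features = {"all_atom_mask": torch.ones(2, 80, 37)}  # no recycling dimension

features = tensor_tree_map(lambda t: t[..., -1], features)
print(features["aatype"].shape)  # torch.Size([2, 64])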
def validation_step(self, batch, batch_idx):
features,gt_features = batch
# At the start of validation, load the EMA weights
if(self.cached_weights is None):
# model.state_dict() contains references to model weights rather
@@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
self.model.load_state_dict(self.ema.state_dict()["params"])
# Run the model
outputs = self(batch)
outputs = self(features)
# Compute loss and other metrics
batch["use_clamped_fape"] = 0.
features["use_clamped_fape"] = 0.
_, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
outputs, (features,gt_features), _return_breakdown=True
)
self._log(loss_breakdown, batch, outputs, train=False)
self._log(loss_breakdown, features, outputs, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal