Merge pull request #353 from dingquanyu/permutation
Update multi-chain permutation and training codes
This commit is contained in:
commit
0ca661460d
|
@ -24,7 +24,7 @@ from openfold.utils.tensor_utils import (
|
|||
tensor_tree_map,
|
||||
)
|
||||
import random
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temp_fasta_file(sequence_str):
|
||||
|
@ -443,7 +443,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
|
||||
|
||||
def _parse_mmcif(self, path, file_id, alignment_dir, alignment_index):
|
||||
def _parse_mmcif(self, path, file_id,alignment_dir, alignment_index):
|
||||
with open(path, 'r') as f:
|
||||
mmcif_string = f.read()
|
||||
|
||||
|
@ -462,7 +462,7 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
|
|||
mmcif=mmcif_object,
|
||||
alignment_dir=alignment_dir,
|
||||
alignment_index=alignment_index
|
||||
)
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
@ -471,10 +471,11 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
|
|||
|
||||
def idx_to_mmcif_id(self, idx):
|
||||
return self._mmcifs[idx]
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
mmcif_id = self.idx_to_mmcif_id(idx)
|
||||
alignment_index = None
|
||||
|
||||
if(self.mode == 'train' or self.mode == 'eval'):
|
||||
path = os.path.join(self.data_dir, f"{mmcif_id}")
|
||||
ext = None
|
||||
|
@ -503,19 +504,19 @@ class OpenFoldSingleMultimerDataset(torch.utils.data.Dataset):
|
|||
|
||||
if (self._output_raw):
|
||||
return data
|
||||
|
||||
|
||||
# process all_chain_features
|
||||
data = self.feature_pipeline.process_features(data,
|
||||
data,ground_truth = self.feature_pipeline.process_features(data,
|
||||
mode=self.mode,
|
||||
is_multimer=True)
|
||||
|
||||
|
||||
# if it's inference mode, only need all_chain_features
|
||||
data["batch_idx"] = torch.tensor(
|
||||
[idx for _ in range(data["aatype"].shape[-1])],
|
||||
dtype=torch.int64,
|
||||
device=data["aatype"].device)
|
||||
|
||||
return data
|
||||
return data, ground_truth
|
||||
|
||||
def __len__(self):
|
||||
return len(self._chain_ids)
|
||||
|
@ -723,9 +724,9 @@ class OpenFoldMultimerDataset(torch.utils.data.Dataset):
|
|||
mmcif_id = dataset.idx_to_mmcif_id(i)
|
||||
mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
|
||||
if deterministic_multimer_train_filter(mmcif_data_cache_entry,
|
||||
max_resolution=9,
|
||||
minimum_number_of_residues=5):
|
||||
max_resolution=9):
|
||||
selected_idx.append(i)
|
||||
logging.info(f"Originally {len(mmcif_data_cache)} mmcifs. After filtering: {len(selected_idx)}")
|
||||
else:
|
||||
selected_idx = list(range(len(dataset._mmcif_id_to_idx_dict)))
|
||||
return selected_idx
|
||||
|
|
|
@ -81,7 +81,7 @@ def np_example_to_features(
|
|||
seq_length = np_example["seq_length"]
|
||||
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
|
||||
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
|
||||
|
||||
|
||||
if "deletion_matrix_int" in np_example:
|
||||
np_example["deletion_matrix"] = np_example.pop(
|
||||
"deletion_matrix_int"
|
||||
|
@ -90,15 +90,29 @@ def np_example_to_features(
|
|||
tensor_dict = np_to_tensor_dict(
|
||||
np_example=np_example, features=feature_names
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
if(not is_multimer):
|
||||
features = input_pipeline.process_tensors_from_config(
|
||||
tensor_dict,
|
||||
cfg.common,
|
||||
cfg[mode],
|
||||
)
|
||||
if is_multimer:
|
||||
if mode == 'train':
|
||||
features,gt_features = input_pipeline_multimer.process_tensors_from_config(
|
||||
tensor_dict,
|
||||
cfg.common,
|
||||
cfg[mode],
|
||||
is_training=True
|
||||
)
|
||||
|
||||
return {k: v for k, v in features.items()}, gt_features
|
||||
else:
|
||||
features = input_pipeline_multimer.process_tensors_from_config(
|
||||
tensor_dict,
|
||||
cfg.common,
|
||||
cfg[mode],
|
||||
is_training=False
|
||||
)
|
||||
return {k: v for k, v in features.items()}
|
||||
|
||||
else:
|
||||
features = input_pipeline_multimer.process_tensors_from_config(
|
||||
features = input_pipeline.process_tensors_from_config(
|
||||
tensor_dict,
|
||||
cfg.common,
|
||||
cfg[mode],
|
||||
|
|
|
@ -21,19 +21,8 @@ from openfold.data import (
|
|||
data_transforms_multimer,
|
||||
)
|
||||
|
||||
|
||||
def nonensembled_transform_fns(common_cfg, mode_cfg):
|
||||
"""Input pipeline data transformers that are not ensembled."""
|
||||
transforms = [
|
||||
data_transforms.cast_to_64bit_ints,
|
||||
data_transforms_multimer.make_msa_profile,
|
||||
data_transforms_multimer.create_target_feat,
|
||||
data_transforms.make_atom14_masks,
|
||||
]
|
||||
|
||||
if mode_cfg.supervised:
|
||||
transforms.extend(
|
||||
[
|
||||
def grountruth_transforms_fns():
|
||||
transforms = [data_transforms.make_atom14_masks,
|
||||
data_transforms.make_atom14_positions,
|
||||
data_transforms.atom37_to_frames,
|
||||
data_transforms.atom37_to_torsion_angles(""),
|
||||
|
@ -41,7 +30,16 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
|
|||
data_transforms.get_backbone_frames,
|
||||
data_transforms.get_chi_angles,
|
||||
]
|
||||
)
|
||||
return transforms
|
||||
|
||||
def nonensembled_transform_fns():
|
||||
"""Input pipeline data transformers that are not ensembled."""
|
||||
transforms = [
|
||||
data_transforms.cast_to_64bit_ints,
|
||||
data_transforms_multimer.make_msa_profile,
|
||||
data_transforms_multimer.create_target_feat,
|
||||
data_transforms.make_atom14_masks
|
||||
]
|
||||
|
||||
return transforms
|
||||
|
||||
|
@ -114,11 +112,29 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
|
|||
|
||||
return transforms
|
||||
|
||||
def prepare_ground_truth_features(tensors):
|
||||
"""Prepare ground truth features that are only needed for loss calculation during training"""
|
||||
|
||||
def process_tensors_from_config(tensors, common_cfg, mode_cfg):
|
||||
GROUNDTRUTH_FEATURES=['all_atom_mask', 'all_atom_positions','asym_id','sym_id','entity_id']
|
||||
gt_tensors = {k:v for k,v in tensors.items() if k in GROUNDTRUTH_FEATURES}
|
||||
gt_tensors['aatype'] = tensors['aatype'].to(torch.long)
|
||||
gt_tensors = compose(grountruth_transforms_fns())(gt_tensors)
|
||||
return gt_tensors
|
||||
|
||||
def process_tensors_from_config(tensors, common_cfg, mode_cfg,is_training=False):
|
||||
"""Based on the config, apply filters and transformations to the data."""
|
||||
|
||||
if is_training:
|
||||
gt_tensors= prepare_ground_truth_features(tensors)
|
||||
|
||||
ensemble_seed = random.randint(0, torch.iinfo(torch.int32).max)
|
||||
tensors['aatype'] = tensors['aatype'].to(torch.long)
|
||||
nonensembled = nonensembled_transform_fns()
|
||||
tensors = compose(nonensembled)(tensors)
|
||||
if("no_recycling_iters" in tensors):
|
||||
num_recycling = int(tensors["no_recycling_iters"])
|
||||
else:
|
||||
num_recycling = common_cfg.max_recycling_iters
|
||||
|
||||
def wrap_ensemble_fn(data, i):
|
||||
"""Function to be mapped over the ensemble dimension."""
|
||||
|
@ -132,28 +148,14 @@ def process_tensors_from_config(tensors, common_cfg, mode_cfg):
|
|||
d["ensemble_index"] = i
|
||||
return fn(d)
|
||||
|
||||
no_templates = True
|
||||
if("template_aatype" in tensors):
|
||||
no_templates = tensors["template_aatype"].shape[0] == 0
|
||||
|
||||
nonensembled = nonensembled_transform_fns(
|
||||
common_cfg,
|
||||
mode_cfg,
|
||||
)
|
||||
|
||||
tensors = compose(nonensembled)(tensors)
|
||||
|
||||
if("no_recycling_iters" in tensors):
|
||||
num_recycling = int(tensors["no_recycling_iters"])
|
||||
else:
|
||||
num_recycling = common_cfg.max_recycling_iters
|
||||
|
||||
tensors = map_fn(
|
||||
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_recycling + 1)
|
||||
)
|
||||
|
||||
return tensors
|
||||
|
||||
if is_training:
|
||||
return tensors,gt_tensors
|
||||
else:
|
||||
return tensors
|
||||
|
||||
@data_transforms.curry1
|
||||
def compose(x, fs):
|
||||
|
|
|
@ -1700,9 +1700,6 @@ def compute_rmsd(
|
|||
atom_mask: torch.Tensor = None,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
# shape check
|
||||
true_atom_pos = true_atom_pos
|
||||
pred_atom_pos = pred_atom_pos
|
||||
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
|
||||
del true_atom_pos
|
||||
del pred_atom_pos
|
||||
|
@ -1784,20 +1781,23 @@ def get_optimal_transform(
|
|||
return r, x
|
||||
|
||||
|
||||
def get_least_asym_entity_or_longest_length(batch):
|
||||
def get_least_asym_entity_or_longest_length(batch,input_asym_id):
|
||||
"""
|
||||
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select
|
||||
one of the A as anchor
|
||||
|
||||
If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
|
||||
then choose one of the corresponding subunits as anchor
|
||||
"""
|
||||
REQUIRED_FEATURES = ['entity_id','asym_id']
|
||||
seq_length = batch['seq_length'].item()
|
||||
|
||||
# remove padding part before selecting candidate
|
||||
remove_padding = lambda t: torch.index_select(t,dim=1,index=torch.arange(seq_length,device=t.device))
|
||||
batch = {k:tensor_tree_map(remove_padding,batch[k]) for k in REQUIRED_FEATURES}
|
||||
Args:
|
||||
batch: in this funtion batch is the full ground truth features
|
||||
input_asym_id: A list of aym_ids that are in the cropped input features
|
||||
|
||||
Return:
|
||||
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
|
||||
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
|
||||
"""
|
||||
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(batch)
|
||||
unique_entity_ids = torch.unique(batch["entity_id"])
|
||||
entity_asym_count = {}
|
||||
entity_length = {}
|
||||
|
@ -1822,19 +1822,15 @@ def get_least_asym_entity_or_longest_length(batch):
|
|||
if len(least_asym_entities) > 1:
|
||||
least_asym_entities = random.choice(least_asym_entities)
|
||||
assert len(least_asym_entities)==1
|
||||
best_pred_asym = torch.unique(batch["asym_id"][batch["entity_id"] == least_asym_entities[0]])
|
||||
|
||||
# If there is more than one chain in the predicted output that has the same sequence
|
||||
# as the chosen ground truth anchor, then randomly picke one
|
||||
if len(best_pred_asym) > 1:
|
||||
best_pred_asym = random.choice(best_pred_asym)
|
||||
return least_asym_entities[0], best_pred_asym
|
||||
least_asym_entities = least_asym_entities[0]
|
||||
|
||||
anchor_gt_asym_id = random.choice(entity_2_asym_list[least_asym_entities])
|
||||
anchor_pred_asym_ids = [id for id in entity_2_asym_list[least_asym_entities] if id in input_asym_id]
|
||||
return anchor_gt_asym_id, anchor_pred_asym_ids
|
||||
|
||||
def greedy_align(
|
||||
batch,
|
||||
per_asym_residue_index,
|
||||
unique_asym_ids,
|
||||
entity_2_asym_list,
|
||||
pred_ca_pos,
|
||||
pred_ca_mask,
|
||||
|
@ -1847,6 +1843,7 @@ def greedy_align(
|
|||
"""
|
||||
used = [False for _ in range(len(true_ca_poses))]
|
||||
align = []
|
||||
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
|
||||
for cur_asym_id in unique_asym_ids:
|
||||
i = int(cur_asym_id - 1)
|
||||
asym_mask = batch["asym_id"] == cur_asym_id
|
||||
|
@ -1884,9 +1881,10 @@ def pad_features(feature_tensor,nres_pad,pad_dim):
|
|||
return torch.concat((feature_tensor,padding_tensor),dim=pad_dim)
|
||||
|
||||
|
||||
def merge_labels(per_asym_residue_index, labels, align,original_nres):
|
||||
def merge_labels(per_asym_residue_index,labels, align,original_nres):
|
||||
"""
|
||||
per_asym_residue_index: A dictionary that record which asym_id corresponds to which regions of residues in the multimer complex.
|
||||
Merge ground truth labels according to the permutation results
|
||||
|
||||
labels: list of original ground truth feats
|
||||
align: list of tuples, each entry specify the corresponding label of the asym.
|
||||
|
||||
|
@ -1898,15 +1896,12 @@ def merge_labels(per_asym_residue_index, labels, align,original_nres):
|
|||
cur_out = {}
|
||||
for i, j in align:
|
||||
label = labels[j][k]
|
||||
cur_num_res = labels[j]['aatype'].shape[-1]
|
||||
# to 1-based
|
||||
cur_residue_index = per_asym_residue_index[i + 1]
|
||||
if len(v.shape)<=1 or "template" in k or "row_mask" in k :
|
||||
continue
|
||||
else:
|
||||
dimension_to_merge = label.shape.index(cur_num_res) if cur_num_res in label.shape else 0
|
||||
if k =='all_atom_positions':
|
||||
dimension_to_merge=1
|
||||
dimension_to_merge = 1
|
||||
cur_out[i] = label.index_select(dimension_to_merge,cur_residue_index)
|
||||
cur_out = [x[1] for x in sorted(cur_out.items())]
|
||||
if len(cur_out)>0:
|
||||
|
@ -2037,19 +2032,14 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
|
|||
def __init__(self, config):
|
||||
super(AlphaFoldMultimerLoss, self).__init__(config)
|
||||
self.config = config
|
||||
@staticmethod
|
||||
def determine_split_dim(batch)->dict:
|
||||
"""A method to determine which dimension to split in split_ground_truth_labels"""
|
||||
padded_dim = batch['aatype'].shape[-1]
|
||||
dim_dict = {k:list(v.shape).index(padded_dim) for k,v in batch.items() if padded_dim in v.shape}
|
||||
return dim_dict
|
||||
|
||||
@staticmethod
|
||||
def split_ground_truth_labels(batch,REQUIRED_FEATURES,dim_dict):
|
||||
def split_ground_truth_labels(batch,REQUIRED_FEATURES,split_dim=1):
|
||||
"""
|
||||
Splits ground truth features according to chains
|
||||
|
||||
Returns a list of feature dictionaries with only necessary ground truth features
|
||||
Returns:
|
||||
a list of feature dictionaries with only necessary ground truth features
|
||||
required to finish multi-chain permutation
|
||||
"""
|
||||
unique_asym_ids, asym_id_counts = torch.unique(batch["asym_id"], sorted=False,return_counts=True)
|
||||
|
@ -2061,11 +2051,85 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
|
|||
unique_asym_ids.append(padding_asym_id)
|
||||
asym_id_counts.append(padding_asym_counts)
|
||||
|
||||
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=dim_dict[k])] for k, value in batch.items() if k in REQUIRED_FEATURES])))
|
||||
labels = list(map(dict, zip(*[[(k, v) for v in torch.split(value, asym_id_counts, dim=split_dim)] for k, value in batch.items() if k in REQUIRED_FEATURES])))
|
||||
return labels
|
||||
|
||||
@staticmethod
|
||||
def multi_chain_perm_align(out, batch, dim_dict,permutate_chains=True):
|
||||
def get_per_asym_residue_index(features):
|
||||
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i!=0]
|
||||
per_asym_residue_index = {}
|
||||
for cur_asym_id in unique_asym_ids:
|
||||
asym_mask = (features["asym_id"] == cur_asym_id).bool()
|
||||
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(features["residue_index"], asym_mask)
|
||||
|
||||
return per_asym_residue_index
|
||||
|
||||
@staticmethod
|
||||
def get_entity_2_asym_list(batch):
|
||||
"""
|
||||
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
|
||||
|
||||
Args:
|
||||
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
|
||||
|
||||
Returns:
|
||||
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
|
||||
associated with each entity.
|
||||
"""
|
||||
entity_2_asym_list = {}
|
||||
unique_entity_ids = torch.unique(batch["entity_id"])
|
||||
for cur_ent_id in unique_entity_ids:
|
||||
ent_mask = batch["entity_id"] == cur_ent_id
|
||||
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
|
||||
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
|
||||
return entity_2_asym_list
|
||||
|
||||
@staticmethod
|
||||
def calculate_input_mask(true_ca_masks,anchor_gt_idx,anchor_gt_residue,
|
||||
asym_mask,pred_ca_mask):
|
||||
"""
|
||||
Calculate an input mask for downstream optimal transformation computation
|
||||
|
||||
Args:
|
||||
true_ca_masks (Tensor): ca mask from ground truth.
|
||||
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
|
||||
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
|
||||
pred_ca_mask (Tensor): ca mask from predicted structure.
|
||||
|
||||
Returns:
|
||||
input_mask (Tensor): A boolean mask
|
||||
"""
|
||||
pred_ca_mask = torch.squeeze(pred_ca_mask,0)
|
||||
asym_mask = torch.squeeze(asym_mask,0)
|
||||
anchor_pred_mask = pred_ca_mask[asym_mask]
|
||||
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx],1,anchor_gt_residue)
|
||||
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
|
||||
return input_mask
|
||||
|
||||
@staticmethod
|
||||
def calculate_optimal_transform(true_ca_poses,
|
||||
anchor_gt_idx,anchor_gt_residue,
|
||||
true_ca_masks,pred_ca_mask,
|
||||
asym_mask,
|
||||
pred_ca_pos):
|
||||
input_mask = AlphaFoldMultimerLoss.calculate_input_mask(true_ca_masks,
|
||||
anchor_gt_idx,anchor_gt_residue,
|
||||
asym_mask,
|
||||
pred_ca_mask)
|
||||
input_mask = torch.squeeze(input_mask,0)
|
||||
pred_ca_pos = torch.squeeze(pred_ca_pos,0)
|
||||
asym_mask = torch.squeeze(asym_mask,0)
|
||||
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx],1,anchor_gt_residue)
|
||||
anchor_pred_pos = pred_ca_pos[asym_mask]
|
||||
r, x = get_optimal_transform(
|
||||
anchor_pred_pos, torch.squeeze(anchor_true_pos,0),
|
||||
mask=input_mask
|
||||
)
|
||||
|
||||
return r, x
|
||||
|
||||
@staticmethod
|
||||
def multi_chain_perm_align(out, batch,permutate_chains=False):
|
||||
"""
|
||||
A class method that first permutate chains in ground truth first
|
||||
before calculating the loss.
|
||||
|
@ -2073,80 +2137,73 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
|
|||
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
|
||||
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
|
||||
"""
|
||||
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
|
||||
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
|
||||
assert isinstance(labels, list)
|
||||
ca_idx = rc.atom_order["CA"]
|
||||
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
|
||||
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
|
||||
|
||||
true_ca_poses = [
|
||||
l["all_atom_positions"][..., ca_idx, :] for l in labels
|
||||
] # list([nres, 3])
|
||||
|
||||
true_ca_masks = [
|
||||
l["all_atom_mask"][..., ca_idx].long() for l in labels
|
||||
] # list([nres,])
|
||||
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i!=0]
|
||||
|
||||
per_asym_residue_index = {}
|
||||
for cur_asym_id in unique_asym_ids:
|
||||
asym_mask = (batch["asym_id"] == cur_asym_id).bool()
|
||||
per_asym_residue_index[int(cur_asym_id)] = torch.masked_select(batch["residue_index"], asym_mask)
|
||||
feature, ground_truth = batch
|
||||
del batch
|
||||
if permutate_chains:
|
||||
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch)
|
||||
print(f"anchor_gt_asym:{anchor_gt_asym} anchor_pred_asym:{anchor_pred_asym}")
|
||||
best_rmsd = float('inf')
|
||||
best_align = None
|
||||
# First select anchors from predicted structures and ground truths
|
||||
anchor_gt_asym, anchor_pred_asym_ids = get_least_asym_entity_or_longest_length(ground_truth,feature['asym_id'])
|
||||
entity_2_asym_list = AlphaFoldMultimerLoss.get_entity_2_asym_list(ground_truth)
|
||||
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
|
||||
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
|
||||
assert isinstance(labels, list)
|
||||
del ground_truth
|
||||
anchor_gt_idx = int(anchor_gt_asym) - 1
|
||||
# Then calculate optimal transform by aligning anchors
|
||||
ca_idx = rc.atom_order["CA"]
|
||||
pred_ca_pos = out["final_atom_positions"][..., ca_idx, :] # [bsz, nres, 3]
|
||||
pred_ca_mask = out["final_atom_mask"][..., ca_idx].to(dtype=pred_ca_pos.dtype) # [bsz, nres]
|
||||
|
||||
unique_entity_ids = torch.unique(batch["entity_id"])
|
||||
entity_2_asym_list = {}
|
||||
for cur_ent_id in unique_entity_ids:
|
||||
ent_mask = batch["entity_id"] == cur_ent_id
|
||||
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
|
||||
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
|
||||
asym_mask = (batch["asym_id"] == anchor_pred_asym).bool()
|
||||
anchor_residue_idx = per_asym_residue_index[int(anchor_pred_asym)]
|
||||
anchor_true_pos = torch.index_select(true_ca_poses[anchor_gt_idx], 1, anchor_residue_idx)
|
||||
anchor_pred_pos = pred_ca_pos[0][asym_mask[0]]
|
||||
|
||||
anchor_true_mask = torch.index_select(true_ca_masks[anchor_gt_idx], 1, anchor_residue_idx)
|
||||
anchor_pred_mask = pred_ca_mask[0][asym_mask[0]]
|
||||
|
||||
input_mask = (anchor_true_mask * anchor_pred_mask).bool()
|
||||
r, x = get_optimal_transform(
|
||||
anchor_pred_pos, anchor_true_pos[0],
|
||||
mask=input_mask[0]
|
||||
)
|
||||
del input_mask # just to save memory
|
||||
del anchor_pred_mask
|
||||
del anchor_true_mask
|
||||
gc.collect()
|
||||
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
|
||||
del true_ca_poses
|
||||
gc.collect()
|
||||
align = greedy_align(
|
||||
batch,
|
||||
per_asym_residue_index,
|
||||
unique_asym_ids,
|
||||
entity_2_asym_list,
|
||||
pred_ca_pos,
|
||||
pred_ca_mask,
|
||||
aligned_true_ca_poses,
|
||||
true_ca_masks,
|
||||
)
|
||||
|
||||
del aligned_true_ca_poses, true_ca_masks
|
||||
del r, x
|
||||
true_ca_poses = [
|
||||
l["all_atom_positions"][..., ca_idx, :] for l in labels
|
||||
] # list([nres, 3])
|
||||
true_ca_masks = [
|
||||
l["all_atom_mask"][..., ca_idx].long() for l in labels
|
||||
] # list([nres,])
|
||||
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
|
||||
for candidate_pred_anchor in anchor_pred_asym_ids:
|
||||
asym_mask = (feature["asym_id"] == candidate_pred_anchor).bool()
|
||||
anchor_gt_residue = per_asym_residue_index[int(candidate_pred_anchor)]
|
||||
r, x = AlphaFoldMultimerLoss.calculate_optimal_transform(true_ca_poses,
|
||||
anchor_gt_idx,anchor_gt_residue,
|
||||
true_ca_masks,pred_ca_mask,
|
||||
asym_mask,
|
||||
pred_ca_pos
|
||||
)
|
||||
aligned_true_ca_poses = [ca.to(r.dtype) @ r + x for ca in true_ca_poses] # apply transforms
|
||||
align = greedy_align(
|
||||
feature,
|
||||
per_asym_residue_index,
|
||||
entity_2_asym_list,
|
||||
pred_ca_pos,
|
||||
pred_ca_mask,
|
||||
aligned_true_ca_poses,
|
||||
true_ca_masks,
|
||||
)
|
||||
merged_labels = merge_labels(per_asym_residue_index,labels,align,
|
||||
original_nres=feature['aatype'].shape[-1])
|
||||
rmsd = compute_rmsd(true_atom_pos = merged_labels['all_atom_positions'][..., ca_idx, :].to(r.dtype) @ r + x,
|
||||
pred_atom_pos = pred_ca_pos,
|
||||
atom_mask = (pred_ca_mask * merged_labels['all_atom_mask'][..., ca_idx].long()).bool())
|
||||
if rmsd < best_rmsd:
|
||||
best_rmsd = rmsd
|
||||
best_align = align
|
||||
del r,x
|
||||
del true_ca_masks,aligned_true_ca_poses
|
||||
del pred_ca_pos, pred_ca_mask
|
||||
del anchor_pred_pos, anchor_true_pos
|
||||
gc.collect()
|
||||
print(f"finished multi-chain permutation and final align is {align}")
|
||||
|
||||
else:
|
||||
align = list(enumerate(range(len(labels))))
|
||||
per_asym_residue_index = AlphaFoldMultimerLoss.get_per_asym_residue_index(feature)
|
||||
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
|
||||
REQUIRED_FEATURES=["all_atom_mask","all_atom_positions"])
|
||||
best_align = list(enumerate(range(len(labels))))
|
||||
|
||||
return best_align, per_asym_residue_index
|
||||
|
||||
|
||||
return align, per_asym_residue_index
|
||||
|
||||
def forward(self, out, features, _return_breakdown=False,permutate_chains=True):
|
||||
def forward(self, out, batch, _return_breakdown=False,permutate_chains=True):
|
||||
"""
|
||||
Overwrite AlphaFoldLoss forward function so that
|
||||
it first compute multi-chain permutation
|
||||
|
@ -2156,22 +2213,25 @@ class AlphaFoldMultimerLoss(AlphaFoldLoss):
|
|||
batch: a pair of input features and its corresponding ground truth structure
|
||||
"""
|
||||
# first check if it is a monomer
|
||||
features, ground_truth = batch
|
||||
del batch
|
||||
is_monomer = len(torch.unique(features['asym_id']))==1 or torch.unique(features['asym_id']).tolist()==[0,1]
|
||||
|
||||
if not is_monomer:
|
||||
permutate_chains = True
|
||||
# first determin which dimension in the tensor to split into individual ground truth labels
|
||||
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(features)
|
||||
|
||||
# Then permutate ground truth chains before calculating the loss
|
||||
align, per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out, features,dim_dict=dim_dict,
|
||||
permutate_chains=permutate_chains)
|
||||
# Then permutate ground truth chains before calculating the loss
|
||||
align,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
|
||||
(features,ground_truth),
|
||||
permutate_chains=permutate_chains)
|
||||
|
||||
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(features,dim_dict=dim_dict,
|
||||
REQUIRED_FEATURES=[i for i in features.keys() if i in dim_dict])
|
||||
# reorder ground truth labels according to permutation results
|
||||
labels = merge_labels(per_asym_residue_index,labels,align,
|
||||
original_nres=features['aatype'].shape[-1])
|
||||
features.update(labels)
|
||||
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(ground_truth,
|
||||
REQUIRED_FEATURES=[i for i in ground_truth.keys()])
|
||||
|
||||
# reorder ground truth labels according to permutation results
|
||||
labels = merge_labels(per_asym_residue_index,labels,align,
|
||||
original_nres=features['aatype'].shape[-1])
|
||||
features.update(labels)
|
||||
|
||||
if (not _return_breakdown):
|
||||
cum_loss = self.loss(out, features, _return_breakdown)
|
||||
|
|
|
@ -102,9 +102,10 @@ class TestPermutation(unittest.TestCase):
|
|||
batch['all_atom_mask'] = true_atom_mask
|
||||
|
||||
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
|
||||
aligns,_ = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
|
||||
aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,batch,
|
||||
dim_dict,
|
||||
permutate_chains=True)
|
||||
print(f"##### aligns is {aligns}")
|
||||
possible_outcome = [[(0,1),(1,0),(2,3),(3,4),(4,2)],[(0,0),(1,1),(2,3),(3,4),(4,2)]]
|
||||
wrong_outcome = [[(0,1),(1,0),(2,4),(3,2),(4,3)],[(0,0),(1,1),(2,2),(3,3),(4,4)]]
|
||||
self.assertIn(aligns,possible_outcome)
|
||||
|
@ -151,15 +152,15 @@ class TestPermutation(unittest.TestCase):
|
|||
tensor_to_cuda = lambda t: t.to('cuda')
|
||||
batch = tensor_tree_map(tensor_to_cuda,batch)
|
||||
dim_dict = AlphaFoldMultimerLoss.determine_split_dim(batch)
|
||||
aligns,per_asym_residue_index = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
|
||||
aligns = AlphaFoldMultimerLoss.multi_chain_perm_align(out,
|
||||
batch,
|
||||
dim_dict,
|
||||
permutate_chains=True)
|
||||
|
||||
print(f"##### aligns is {aligns}")
|
||||
labels = AlphaFoldMultimerLoss.split_ground_truth_labels(batch,dim_dict=dim_dict,
|
||||
REQUIRED_FEATURES=[i for i in batch.keys() if i in dim_dict])
|
||||
|
||||
labels = merge_labels(per_asym_residue_index,labels,aligns,
|
||||
labels = merge_labels(labels,aligns,
|
||||
original_nres=batch['aatype'].shape[-1])
|
||||
|
||||
self.assertTrue(torch.equal(labels['residue_index'],batch['residue_index']))
|
||||
|
|
|
@ -273,27 +273,29 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
|
|||
return self.model(batch)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
features,gt_features = batch
|
||||
# Log it
|
||||
if(self.ema.device != batch["aatype"].device):
|
||||
self.ema.to(batch["aatype"].device)
|
||||
if(self.ema.device != features["aatype"].device):
|
||||
self.ema.to(features["aatype"].device)
|
||||
|
||||
# Run the model
|
||||
outputs = self(batch)
|
||||
outputs = self(features)
|
||||
|
||||
# Remove the recycling dimension
|
||||
batch = tensor_tree_map(lambda t: t[..., -1], batch)
|
||||
features = tensor_tree_map(lambda t: t[..., -1], features)
|
||||
|
||||
# Compute loss
|
||||
loss, loss_breakdown = self.loss(
|
||||
outputs, batch, _return_breakdown=True
|
||||
outputs, (features,gt_features), _return_breakdown=True
|
||||
)
|
||||
|
||||
# Log it
|
||||
self._log(loss_breakdown, batch, outputs)
|
||||
self._log(loss_breakdown, features, outputs)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
features,gt_features = batch
|
||||
# At the start of validation, load the EMA weights
|
||||
if(self.cached_weights is None):
|
||||
# model.state_dict() contains references to model weights rather
|
||||
|
@ -304,15 +306,15 @@ class OpenFoldMultimerWrapper(OpenFoldWrapper):
|
|||
self.model.load_state_dict(self.ema.state_dict()["params"])
|
||||
|
||||
# Run the model
|
||||
outputs = self(batch)
|
||||
outputs = self(features)
|
||||
|
||||
# Compute loss and other metrics
|
||||
batch["use_clamped_fape"] = 0.
|
||||
features["use_clamped_fape"] = 0.
|
||||
_, loss_breakdown = self.loss(
|
||||
outputs, batch, _return_breakdown=True
|
||||
outputs, (features,gt_features), _return_breakdown=True
|
||||
)
|
||||
|
||||
self._log(loss_breakdown, batch, outputs, train=False)
|
||||
self._log(loss_breakdown, features, outputs, train=False)
|
||||
|
||||
def validation_epoch_end(self, _):
|
||||
# Restore the model weights to normal
|
||||
|
|
Loading…
Reference in New Issue