Fixes for multimer config features and cropping

This commit is contained in:
Christina Floristean 2023-08-16 17:33:12 -04:00
parent 3de188e9d0
commit da5d0e7d30
4 changed files with 56 additions and 13 deletions

View File

@ -156,6 +156,10 @@ def model_config(
elif "multimer" in name:
c.update(multimer_config_update.copy_and_resolve_references())
# Not used in multimer
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252
@ -676,11 +680,57 @@ config = mlc.ConfigDict(
multimer_config_update = mlc.ConfigDict({
"globals": {
"is_multimer": True,
"bfloat16": False, # TODO: Change to True when implemented
"bfloat16": False, # TODO: Change to True when implemented
"bfloat16_output": False
},
"data": {
"common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
# "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats
# "all_crops_all_chains_mask": [],
# "all_crops_all_chains_positions": [],
# "all_crops_all_chains_residue_ids": [],
"assembly_num_chains": [],
"asym_id": [NUM_RES],
"atom14_atom_exists": [NUM_RES, None],
"atom37_atom_exists": [NUM_RES, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"cluster_bias_mask": [NUM_MSA_SEQ],
"cluster_profile": [NUM_MSA_SEQ, NUM_RES, None],
"cluster_deletion_mean": [NUM_MSA_SEQ, NUM_RES],
"deletion_matrix": [NUM_MSA_SEQ, NUM_RES],
"deletion_mean": [NUM_RES],
"entity_id": [NUM_RES],
"entity_mask": [NUM_RES],
"extra_deletion_matrix": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
# "mem_peak": [],
"msa": [NUM_MSA_SEQ, NUM_RES],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_profile": [NUM_RES, None],
"num_alignments": [],
"num_templates": [],
# "queue_size": [],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"seq_length": [],
"seq_mask": [NUM_RES],
"sym_id": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 20,
"unsupervised_features": [
"aatype",
@ -741,7 +791,6 @@ multimer_config_update = mlc.ConfigDict({
"tri_mul_first": True,
"fuse_projection_weights": True
},
"template_pointwise_attention": None, # Not used in Multimer
"c_t": c_t,
"c_z": c_z,
"use_unit_vector": True
@ -785,8 +834,7 @@ multimer_config_update = mlc.ConfigDict({
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5
},
"backbone": None # Not used in Multimer
}
},
"masked_msa": {
"num_classes": 22

View File

@ -77,8 +77,9 @@ def np_example_to_features(
is_multimer: bool = False
):
np_example = dict(np_example)
num_res = int(np_example["seq_length"][0])
seq_length = np_example["seq_length"]
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example:

View File

@ -31,11 +31,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_atom14_masks,
]
if(common_cfg.use_templates):
transforms.extend([
data_transforms.make_pseudo_beta("template_"),
])
return transforms

View File

@ -274,9 +274,8 @@ def _correct_post_merged_feats(
) -> Mapping[str, np.ndarray]:
"""Adds features that need to be computed/recomputed post merging."""
num_res = np_example['aatype'].shape[0]
np_example['seq_length'] = np.asarray(
[num_res] * num_res,
np_example['aatype'].shape[0],
dtype=np.int32
)
np_example['num_alignments'] = np.asarray(