Fixes for multimer config features and cropping
This commit is contained in:
parent
3de188e9d0
commit
da5d0e7d30
|
@ -156,6 +156,10 @@ def model_config(
|
|||
elif "multimer" in name:
|
||||
c.update(multimer_config_update.copy_and_resolve_references())
|
||||
|
||||
# Not used in multimer
|
||||
del c.model.template.template_pointwise_attention
|
||||
del c.loss.fape.backbone
|
||||
|
||||
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
|
||||
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
|
||||
#c.model.input_embedder.num_msa = 252
|
||||
|
@ -676,11 +680,57 @@ config = mlc.ConfigDict(
|
|||
multimer_config_update = mlc.ConfigDict({
|
||||
"globals": {
|
||||
"is_multimer": True,
|
||||
"bfloat16": False, # TODO: Change to True when implemented
|
||||
"bfloat16": False, # TODO: Change to True when implemented
|
||||
"bfloat16_output": False
|
||||
},
|
||||
"data": {
|
||||
"common": {
|
||||
"feat": {
|
||||
"aatype": [NUM_RES],
|
||||
"all_atom_mask": [NUM_RES, None],
|
||||
"all_atom_positions": [NUM_RES, None, None],
|
||||
# "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats
|
||||
# "all_crops_all_chains_mask": [],
|
||||
# "all_crops_all_chains_positions": [],
|
||||
# "all_crops_all_chains_residue_ids": [],
|
||||
"assembly_num_chains": [],
|
||||
"asym_id": [NUM_RES],
|
||||
"atom14_atom_exists": [NUM_RES, None],
|
||||
"atom37_atom_exists": [NUM_RES, None],
|
||||
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
|
||||
"cluster_bias_mask": [NUM_MSA_SEQ],
|
||||
"cluster_profile": [NUM_MSA_SEQ, NUM_RES, None],
|
||||
"cluster_deletion_mean": [NUM_MSA_SEQ, NUM_RES],
|
||||
"deletion_matrix": [NUM_MSA_SEQ, NUM_RES],
|
||||
"deletion_mean": [NUM_RES],
|
||||
"entity_id": [NUM_RES],
|
||||
"entity_mask": [NUM_RES],
|
||||
"extra_deletion_matrix": [NUM_EXTRA_SEQ, NUM_RES],
|
||||
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
|
||||
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
|
||||
# "mem_peak": [],
|
||||
"msa": [NUM_MSA_SEQ, NUM_RES],
|
||||
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
|
||||
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
|
||||
"msa_profile": [NUM_RES, None],
|
||||
"num_alignments": [],
|
||||
"num_templates": [],
|
||||
# "queue_size": [],
|
||||
"residue_index": [NUM_RES],
|
||||
"residx_atom14_to_atom37": [NUM_RES, None],
|
||||
"residx_atom37_to_atom14": [NUM_RES, None],
|
||||
"resolution": [],
|
||||
"seq_length": [],
|
||||
"seq_mask": [NUM_RES],
|
||||
"sym_id": [NUM_RES],
|
||||
"target_feat": [NUM_RES, None],
|
||||
"template_aatype": [NUM_TEMPLATES, NUM_RES],
|
||||
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
|
||||
"template_all_atom_positions": [
|
||||
NUM_TEMPLATES, NUM_RES, None, None,
|
||||
],
|
||||
"true_msa": [NUM_MSA_SEQ, NUM_RES]
|
||||
},
|
||||
"max_recycling_iters": 20,
|
||||
"unsupervised_features": [
|
||||
"aatype",
|
||||
|
@ -741,7 +791,6 @@ multimer_config_update = mlc.ConfigDict({
|
|||
"tri_mul_first": True,
|
||||
"fuse_projection_weights": True
|
||||
},
|
||||
"template_pointwise_attention": None, # Not used in Multimer
|
||||
"c_t": c_t,
|
||||
"c_z": c_z,
|
||||
"use_unit_vector": True
|
||||
|
@ -785,8 +834,7 @@ multimer_config_update = mlc.ConfigDict({
|
|||
"clamp_distance": 30.0,
|
||||
"loss_unit_distance": 20.0,
|
||||
"weight": 0.5
|
||||
},
|
||||
"backbone": None # Not used in Multimer
|
||||
}
|
||||
},
|
||||
"masked_msa": {
|
||||
"num_classes": 22
|
||||
|
|
|
@ -77,8 +77,9 @@ def np_example_to_features(
|
|||
is_multimer: bool = False
|
||||
):
|
||||
np_example = dict(np_example)
|
||||
|
||||
num_res = int(np_example["seq_length"][0])
|
||||
|
||||
seq_length = np_example["seq_length"]
|
||||
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
|
||||
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
|
||||
|
||||
if "deletion_matrix_int" in np_example:
|
||||
|
|
|
@ -31,11 +31,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
|
|||
data_transforms.make_atom14_masks,
|
||||
]
|
||||
|
||||
if(common_cfg.use_templates):
|
||||
transforms.extend([
|
||||
data_transforms.make_pseudo_beta("template_"),
|
||||
])
|
||||
|
||||
return transforms
|
||||
|
||||
|
||||
|
|
|
@ -274,9 +274,8 @@ def _correct_post_merged_feats(
|
|||
) -> Mapping[str, np.ndarray]:
|
||||
"""Adds features that need to be computed/recomputed post merging."""
|
||||
|
||||
num_res = np_example['aatype'].shape[0]
|
||||
np_example['seq_length'] = np.asarray(
|
||||
[num_res] * num_res,
|
||||
np_example['aatype'].shape[0],
|
||||
dtype=np.int32
|
||||
)
|
||||
np_example['num_alignments'] = np.asarray(
|
||||
|
|
Loading…
Reference in New Issue