diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py index a4dda9c..24c3356 100644 --- a/openfold/data/data_modules.py +++ b/openfold/data/data_modules.py @@ -186,7 +186,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): mmcif=mmcif_object, alignment_dir=alignment_dir, chain_id=chain_id, - alignment_index=alignment_index + alignment_index=alignment_index, + seqemb_mode=self.config.seqemb_mode.enabled ) return data @@ -251,6 +252,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): chain_id=chain_id, alignment_index=alignment_index, _structure_index=structure_index, + seqemb_mode=self.config.seqemb_mode.enabled, ) else: raise ValueError("Extension branch missing") @@ -260,6 +262,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): fasta_path=path, alignment_dir=alignment_dir, alignment_index=alignment_index, + seqemb_mode=self.config.seqemb_mode.enabled, ) if(self._output_raw):