Remove scale when using ds kernel
parent 6ebcd8b4b0
commit a098576109
@@ -379,7 +379,8 @@ class Attention(nn.Module):
     def _prep_qkv(self,
         q_x: torch.Tensor,
         kv_x: torch.Tensor,
-        transpose_qkv_dims: bool = True
+        transpose_qkv_dims: bool = True,
+        apply_scale: bool = True
     ) -> Tuple[
         torch.Tensor, torch.Tensor, torch.Tensor
     ]:
@@ -399,7 +400,8 @@ class Attention(nn.Module):
         k = k.transpose(-2, -3)
         v = v.transpose(-2, -3)

-        q /= math.sqrt(self.c_hidden)
+        if apply_scale:
+            q /= math.sqrt(self.c_hidden)

         return q, k, v
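For reference, a minimal standalone sketch of what the new flag controls (the helper name `scale_query` is illustrative, not part of the patch): Q should be divided by sqrt(c_hidden) exactly once, and `apply_scale=False` defers that division to the attention kernel.

    import math
    import torch

    def scale_query(q: torch.Tensor, c_hidden: int, apply_scale: bool = True) -> torch.Tensor:
        # Scaled dot-product attention divides Q by sqrt(c_hidden) exactly once.
        # When a fused kernel performs that scaling internally, the caller skips
        # it here to avoid dividing twice.
        if apply_scale:
            q = q / math.sqrt(c_hidden)
        return q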
@@ -486,7 +488,9 @@ class Attention(nn.Module):

         # DeepSpeed attention kernel expects Q/K/V of shape [*, Q/K, H, C_hidden]
         # All other attention modules expect Q/K/V of shape [*, H, Q/K, C_hidden]
-        q, k, v = self._prep_qkv(q_x, kv_x, transpose_qkv_dims=not use_deepspeed_evo_attention)
+        q, k, v = self._prep_qkv(q_x, kv_x,
+                                 transpose_qkv_dims=not use_deepspeed_evo_attention,
+                                 apply_scale=not use_deepspeed_evo_attention)

         if is_fp16_enabled():
             use_memory_efficient_kernel = False
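Both keyword arguments key off the same flag: when the DeepSpeed evoformer kernel is used, `_prep_qkv` leaves Q unscaled, presumably because DS4Sci_EvoformerAttention applies the 1/sqrt(c_hidden) factor itself, so pre-scaling it here would divide twice. A self-contained sketch of that failure mode, using plain PyTorch stand-ins for both paths (all names below are illustrative):

    import math
    import torch

    def stock_attention(q, k, v, c_hidden):
        # Reference path: Q is scaled once, outside any fused kernel.
        q = q / math.sqrt(c_hidden)
        return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

    def fused_kernel_stand_in(q, k, v, c_hidden):
        # Stand-in for a kernel that applies the 1/sqrt(c_hidden) scaling internally.
        q = q / math.sqrt(c_hidden)
        return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

    c_hidden = 32
    q, k, v = (torch.randn(2, 4, 8, c_hidden) for _ in range(3))

    reference = stock_attention(q, k, v, c_hidden)
    # Raw Q in, single scaling inside the kernel: the two paths agree.
    print(torch.allclose(reference, fused_kernel_stand_in(q, k, v, c_hidden)))  # True
    # Pre-scaled Q in, scaled again inside the kernel: the outputs diverge.
    print(torch.allclose(reference, fused_kernel_stand_in(q / math.sqrt(c_hidden), k, v, c_hidden)))  # False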
@@ -355,7 +355,7 @@ if __name__ == "__main__":
         help="""Postfix for output prediction filenames"""
     )
     parser.add_argument(
-        "--data_random_seed", type=str, default=None
+        "--data_random_seed", type=int, default=None
     )
     parser.add_argument(
         "--skip_relaxation", action="store_true", default=False,
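The seed flag now parses as an integer up front; left as a string, a value like "42" would fail in whatever seeding call eventually consumes it (for example torch.manual_seed, assuming that is where the value ends up). A quick illustration:

    import argparse

    parser = argparse.ArgumentParser()
    # type=int converts "--data_random_seed 42" to the integer 42 at parse time,
    # so downstream seeding code receives a proper int instead of the string "42".
    parser.add_argument("--data_random_seed", type=int, default=None)

    args = parser.parse_args(["--data_random_seed", "42"])
    assert isinstance(args.data_random_seed, int)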
@@ -15,9 +15,6 @@
 """
 Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
 attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
-
-Note: Some tests are temporarily disabled while we investigate discrepancies related
-to using fused attention.
 """

 import torch
@@ -159,6 +156,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
         with torch.no_grad():
             with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                 model = compare_utils.get_global_pretrained_openfold()
+                model.globals.use_deepspeed_evo_attention = False
                 out_repro = model(batch)

                 # Enable kernel
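The added line pins the first forward pass to stock attention before the kernel is enabled for the comparison pass. A toy, self-contained version of that toggle-and-compare pattern (the module and flag below are stand-ins, not OpenFold code):

    import math
    import torch

    class ToyAttention(torch.nn.Module):
        # Stand-in for a module with a runtime switch like globals.use_deepspeed_evo_attention.
        def __init__(self, c_hidden: int = 16):
            super().__init__()
            self.c_hidden = c_hidden
            self.use_fused_kernel = False

        def forward(self, q, k, v):
            # The toy forward is the same for both settings; in the real test the
            # flag selects between the DeepSpeed kernel and stock attention.
            q = q / math.sqrt(self.c_hidden)
            return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v

    model = ToyAttention()
    q, k, v = (torch.randn(2, 4, 8, model.c_hidden) for _ in range(3))

    model.use_fused_kernel = False   # explicit baseline pass
    baseline = model(q, k, v)
    model.use_fused_kernel = True    # enable the kernel and re-run
    fused = model(q, k, v)

    # With scaling applied exactly once on both paths, the outputs should match.
    torch.testing.assert_close(baseline, fused)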