Remove scale when using ds kernel
This commit is contained in:
parent
6ebcd8b4b0
commit
a098576109
|
@ -379,7 +379,8 @@ class Attention(nn.Module):
|
||||||
def _prep_qkv(self,
|
def _prep_qkv(self,
|
||||||
q_x: torch.Tensor,
|
q_x: torch.Tensor,
|
||||||
kv_x: torch.Tensor,
|
kv_x: torch.Tensor,
|
||||||
transpose_qkv_dims: bool = True
|
transpose_qkv_dims: bool = True,
|
||||||
|
apply_scale: bool = True
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
torch.Tensor, torch.Tensor, torch.Tensor
|
torch.Tensor, torch.Tensor, torch.Tensor
|
||||||
]:
|
]:
|
||||||
|
@ -399,7 +400,8 @@ class Attention(nn.Module):
|
||||||
k = k.transpose(-2, -3)
|
k = k.transpose(-2, -3)
|
||||||
v = v.transpose(-2, -3)
|
v = v.transpose(-2, -3)
|
||||||
|
|
||||||
q /= math.sqrt(self.c_hidden)
|
if apply_scale:
|
||||||
|
q /= math.sqrt(self.c_hidden)
|
||||||
|
|
||||||
return q, k, v
|
return q, k, v
|
||||||
|
|
||||||
|
@ -486,7 +488,9 @@ class Attention(nn.Module):
|
||||||
|
|
||||||
# DeepSpeed attention kernel expects Q/K/V of shape [*, Q/K, H, C_hidden]
|
# DeepSpeed attention kernel expects Q/K/V of shape [*, Q/K, H, C_hidden]
|
||||||
# All other attention modules expect Q/K/V of shape [*, H, Q/K, C_hidden]
|
# All other attention modules expect Q/K/V of shape [*, H, Q/K, C_hidden]
|
||||||
q, k, v = self._prep_qkv(q_x, kv_x, transpose_qkv_dims=not use_deepspeed_evo_attention)
|
q, k, v = self._prep_qkv(q_x, kv_x,
|
||||||
|
transpose_qkv_dims=not use_deepspeed_evo_attention,
|
||||||
|
apply_scale=not use_deepspeed_evo_attention)
|
||||||
|
|
||||||
if is_fp16_enabled():
|
if is_fp16_enabled():
|
||||||
use_memory_efficient_kernel = False
|
use_memory_efficient_kernel = False
|
||||||
|
|
|
@ -355,7 +355,7 @@ if __name__ == "__main__":
|
||||||
help="""Postfix for output prediction filenames"""
|
help="""Postfix for output prediction filenames"""
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data_random_seed", type=str, default=None
|
"--data_random_seed", type=int, default=None
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip_relaxation", action="store_true", default=False,
|
"--skip_relaxation", action="store_true", default=False,
|
||||||
|
|
|
@ -15,9 +15,6 @@
|
||||||
"""
|
"""
|
||||||
Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
|
Unit tests to compare components of OpenFold run with the DeepSpeed memory-efficient
|
||||||
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
|
attention kernel, DS4Sci_EvoformerAttention vs. a stock PyTorch attention implementation.
|
||||||
|
|
||||||
Note: Some tests are temporarily disabled while we investigate discrepancies related
|
|
||||||
to using fused attention.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -159,6 +156,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
||||||
model = compare_utils.get_global_pretrained_openfold()
|
model = compare_utils.get_global_pretrained_openfold()
|
||||||
|
model.globals.use_deepspeed_evo_attention = False
|
||||||
out_repro = model(batch)
|
out_repro = model(batch)
|
||||||
|
|
||||||
# Enable kernel
|
# Enable kernel
|
||||||
|
|
Loading…
Reference in New Issue