Fix matmul inputs dtype (#18585)
This commit is contained in:
parent
c99e984657
commit
86d0b26d6c
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
""" PyTorch DeBERTa model."""
|
||||
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
|
@ -640,8 +639,8 @@ class DisentangledSelfAttention(nn.Module):
|
|||
qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
|
||||
qkvb = [None] * 3
|
||||
|
||||
q = linear(qkvw[0], qkvb[0], query_states)
|
||||
k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
|
||||
q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
|
||||
k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
|
||||
query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
|
||||
|
||||
query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
|
||||
|
@ -650,8 +649,8 @@ class DisentangledSelfAttention(nn.Module):
|
|||
rel_att = None
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
scale_factor = 1 + len(self.pos_att_type)
|
||||
scale = math.sqrt(query_layer.size(-1) * scale_factor)
|
||||
query_layer = query_layer / scale
|
||||
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
|
||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
||||
if self.relative_attention:
|
||||
rel_embeddings = self.pos_dropout(rel_embeddings)
|
||||
|
@ -711,13 +710,13 @@ class DisentangledSelfAttention(nn.Module):
|
|||
if "p2c" in self.pos_att_type:
|
||||
pos_query_layer = self.pos_q_proj(rel_embeddings)
|
||||
pos_query_layer = self.transpose_for_scores(pos_query_layer)
|
||||
pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
|
||||
pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
if query_layer.size(-2) != key_layer.size(-2):
|
||||
r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
|
||||
else:
|
||||
r_pos = relative_pos
|
||||
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
|
||||
p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
|
||||
p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
|
||||
p2c_att = torch.gather(
|
||||
p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
|
||||
).transpose(-1, -2)
|
||||
|
|
|
@ -717,7 +717,9 @@ class DisentangledSelfAttention(nn.Module):
|
|||
if "p2c" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
|
||||
scale, dtype=query_layer.dtype
|
||||
)
|
||||
if self.relative_attention:
|
||||
rel_embeddings = self.pos_dropout(rel_embeddings)
|
||||
rel_att = self.disentangled_attention_bias(
|
||||
|
@ -799,7 +801,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
dim=-1,
|
||||
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
|
||||
)
|
||||
score += c2p_att / scale
|
||||
score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
|
||||
|
||||
# position->content
|
||||
if "p2c" in self.pos_att_type:
|
||||
|
@ -822,7 +824,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
dim=-1,
|
||||
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
|
||||
).transpose(-1, -2)
|
||||
score += p2c_att / scale
|
||||
score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
|
||||
|
||||
return score
|
||||
|
||||
|
|
|
@ -791,7 +791,9 @@ class DisentangledSelfAttention(nn.Module):
|
|||
if "p2c" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
|
||||
scale, dtype=query_layer.dtype
|
||||
)
|
||||
if self.relative_attention:
|
||||
rel_embeddings = self.pos_dropout(rel_embeddings)
|
||||
rel_att = self.disentangled_attention_bias(
|
||||
|
@ -873,7 +875,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
dim=-1,
|
||||
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
|
||||
)
|
||||
score += c2p_att / scale
|
||||
score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
|
||||
|
||||
# position->content
|
||||
if "p2c" in self.pos_att_type:
|
||||
|
@ -896,7 +898,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
dim=-1,
|
||||
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
|
||||
).transpose(-1, -2)
|
||||
score += p2c_att / scale
|
||||
score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
|
||||
|
||||
return score
|
||||
|
||||
|
|
Loading…
Reference in New Issue