Fix matmul inputs dtype (#18585)

Jingya HUANG authored on 2022-08-17 15:59:43 +02:00, committed by GitHub
parent c99e984657
commit 86d0b26d6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 13 deletions
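
Note (not part of the commit): torch.matmul and torch.bmm do not promote mismatched dtypes, they raise a RuntimeError, so every change below casts one operand to the dtype of the tensor it is combined with (or the scale it is divided by). A minimal sketch of the pattern, with arbitrary names, shapes, and dtypes:

import torch

# torch.matmul requires both operands to share a dtype; mixing dtypes raises.
q = torch.randn(2, 8, 64)                        # float32 "queries"
k = torch.randn(2, 8, 64, dtype=torch.float64)   # "keys" left in a different dtype

try:
    torch.matmul(q, k.transpose(-1, -2))
except RuntimeError as err:
    print(err)  # e.g. "expected scalar type Double but found Float"

# The fix used throughout this commit: rebuild one operand with the other's dtype.
scores = torch.matmul(q, torch.tensor(k.transpose(-1, -2), dtype=q.dtype))
print(scores.dtype)  # torch.float32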

View File

@@ -14,7 +14,6 @@
# limitations under the License.
""" PyTorch DeBERTa model."""
-import math
from collections.abc import Sequence
from typing import Optional, Tuple, Union
@@ -640,8 +639,8 @@ class DisentangledSelfAttention(nn.Module):
qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
qkvb = [None] * 3
-q = linear(qkvw[0], qkvb[0], query_states)
-k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+q = linear(qkvw[0], qkvb[0], torch.tensor(query_states, dtype=qkvw[0].dtype))
+k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
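
The q/k/v change above casts query_states and hidden_states to the dtype of the packed in_proj weight chunks before the projection. A standalone sketch (shapes are illustrative; linear stands in for the local helper of the same name in DisentangledSelfAttention.forward):

import torch

def linear(w, b, x):
    # bias-free functional projection, standing in for the local helper this hunk calls
    return torch.matmul(x, w.t()) if b is None else torch.matmul(x, w.t()) + b.t()

hidden_states = torch.randn(2, 16, 768, dtype=torch.float64)  # deliberately not the weight dtype
qkvw = [torch.randn(768, 768) for _ in range(3)]              # float32 weight chunks
qkvb = [None] * 3

q = linear(qkvw[0], qkvb[0], torch.tensor(hidden_states, dtype=qkvw[0].dtype))
k, v = [linear(qkvw[i], qkvb[i], torch.tensor(hidden_states, dtype=qkvw[i].dtype)) for i in range(1, 3)]
print(q.dtype, k.shape)  # torch.float32 torch.Size([2, 16, 768])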
@@ -650,8 +649,8 @@ class DisentangledSelfAttention(nn.Module):
rel_att = None
# Take the dot product between "query" and "key" to get the raw attention scores.
scale_factor = 1 + len(self.pos_att_type)
-scale = math.sqrt(query_layer.size(-1) * scale_factor)
-query_layer = query_layer / scale
+scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+query_layer = query_layer / torch.tensor(scale, dtype=query_layer.dtype)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
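
The scaling in this hunk switches from math.sqrt, which returns a Python float, to a torch.sqrt tensor that can be cast to query_layer's dtype before the division. A small sketch of the rewritten computation, with arbitrary sizes:

import math
import torch

query_layer = torch.randn(24, 16, 64)  # illustrative (batch*heads, seq, head_dim)
scale_factor = 3                        # 1 + len(pos_att_type) in the model

old_scale = math.sqrt(query_layer.size(-1) * scale_factor)  # plain Python float
new_scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)

# Casting the tensor scale to the query dtype keeps the division in one dtype.
query_layer = query_layer / torch.tensor(new_scale, dtype=query_layer.dtype)
print(old_scale, new_scale.item(), query_layer.dtype)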
@@ -711,13 +710,13 @@ class DisentangledSelfAttention(nn.Module):
if "p2c" in self.pos_att_type:
pos_query_layer = self.pos_q_proj(rel_embeddings)
pos_query_layer = self.transpose_for_scores(pos_query_layer)
-pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+pos_query_layer /= torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
if query_layer.size(-2) != key_layer.size(-2):
r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
else:
r_pos = relative_pos
p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
-p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+p2c_att = torch.matmul(key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype))
p2c_att = torch.gather(
p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
).transpose(-1, -2)
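
The position→content product above gets the same treatment: pos_query_layer is cast to key_layer's dtype before torch.matmul. A sketch with stand-in tensors; an equivalent way to express the cast would be pos_query_layer.transpose(-1, -2).to(key_layer.dtype):

import torch

key_layer = torch.randn(24, 16, 64)                             # float32
pos_query_layer = torch.randn(24, 32, 64, dtype=torch.float64)  # deliberately a different dtype

# Cast the second matmul operand to the first operand's dtype, as in the hunk above.
p2c_att = torch.matmul(
    key_layer, torch.tensor(pos_query_layer.transpose(-1, -2), dtype=key_layer.dtype)
)
print(p2c_att.shape, p2c_att.dtype)  # torch.Size([24, 16, 32]) torch.float32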

View File

@@ -717,7 +717,9 @@ class DisentangledSelfAttention(nn.Module):
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+    scale, dtype=query_layer.dtype
+)
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
rel_att = self.disentangled_attention_bias(
@@ -799,7 +801,7 @@ class DisentangledSelfAttention(nn.Module):
dim=-1,
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
)
-score += c2p_att / scale
+score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
# position->content
if "p2c" in self.pos_att_type:
@@ -822,7 +824,7 @@ class DisentangledSelfAttention(nn.Module):
dim=-1,
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
).transpose(-1, -2)
-score += p2c_att / scale
+score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
return score
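
In this bmm-based copy of the attention, the scale is likewise built as a float32 tensor and cast to query_layer's dtype before dividing the raw scores. A self-contained sketch with arbitrary sizes:

import torch

query_layer = torch.randn(24, 16, 64)  # (batch*heads, seq, head_dim), illustrative
key_layer = torch.randn(24, 16, 64)
scale_factor = 3

scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
    scale, dtype=query_layer.dtype
)
print(attention_scores.shape, attention_scores.dtype)  # torch.Size([24, 16, 16]) torch.float32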

View File

@@ -791,7 +791,9 @@ class DisentangledSelfAttention(nn.Module):
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / torch.tensor(
+    scale, dtype=query_layer.dtype
+)
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
rel_att = self.disentangled_attention_bias(
@@ -873,7 +875,7 @@ class DisentangledSelfAttention(nn.Module):
dim=-1,
index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
)
-score += c2p_att / scale
+score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)
# position->content
if "p2c" in self.pos_att_type:
@@ -896,7 +898,7 @@ class DisentangledSelfAttention(nn.Module):
dim=-1,
index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
).transpose(-1, -2)
-score += p2c_att / scale
+score += p2c_att / torch.tensor(scale, dtype=p2c_att.dtype)
return score
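
The c2p/p2c accumulation in the last two hunks divides the gathered scores by the scale cast to their own dtype before the in-place addition. A sketch of the content→position term with made-up sizes and stand-in indices:

import torch

batch_heads, q_len, k_len, att_span = 24, 16, 16, 8
score = torch.zeros(batch_heads, q_len, k_len)
c2p_att = torch.randn(batch_heads, q_len, 2 * att_span)
c2p_pos = torch.randint(0, 2 * att_span, (1, q_len, k_len))   # stand-in clamped relative positions
scale = torch.sqrt(torch.tensor(64, dtype=torch.float) * 3)

c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos.expand([batch_heads, q_len, k_len]))
score += c2p_att / torch.tensor(scale, dtype=c2p_att.dtype)   # divisor matches the score dtype
print(score.shape, score.dtype)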