Deberta V2: Fix critical trace warnings to allow ONNX export (#18272)

* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Use broadcasting instead of repeat

* Implement suggestion

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
This commit is contained in:
iiLaurens 2022-08-11 15:54:43 +02:00 committed by GitHub
parent 5d3f037433
commit d53dffec6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 24 deletions

View File

@ -14,11 +14,9 @@
# limitations under the License.
""" PyTorch DeBERTa-v2 model."""
import math
from collections.abc import Sequence
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
@ -552,11 +550,17 @@ class DebertaV2Encoder(nn.Module):
def make_log_bucket_position(relative_pos, bucket_size, max_position):
sign = np.sign(relative_pos)
sign = torch.sign(relative_pos)
mid = bucket_size // 2
abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
abs_pos = torch.where(
(relative_pos < mid) & (relative_pos > -mid),
torch.tensor(mid - 1).type_as(relative_pos),
torch.abs(relative_pos),
)
log_pos = (
torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
)
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
return bucket_pos
@ -578,12 +582,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
"""
q_ids = np.arange(0, query_size)
k_ids = np.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
q_ids = torch.arange(0, query_size)
k_ids = torch.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
if bucket_size > 0 and max_position > 0:
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
rel_pos_ids = rel_pos_ids.to(torch.long)
rel_pos_ids = rel_pos_ids[:query_size, :]
rel_pos_ids = rel_pos_ids.unsqueeze(0)
return rel_pos_ids
@ -712,7 +716,7 @@ class DisentangledSelfAttention(nn.Module):
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = math.sqrt(query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
@ -787,7 +791,7 @@ class DisentangledSelfAttention(nn.Module):
score = 0
# content->position
if "c2p" in self.pos_att_type:
scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
c2p_att = torch.gather(
@ -799,7 +803,7 @@ class DisentangledSelfAttention(nn.Module):
# position->content
if "p2c" in self.pos_att_type:
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
if key_layer.size(-2) != query_layer.size(-2):
r_pos = build_relative_position(
key_layer.size(-2),

View File

@ -194,11 +194,17 @@ def _compute_mask_indices(
# Copied from transformers.models.deberta_v2.modeling_deberta_v2.make_log_bucket_position
def make_log_bucket_position(relative_pos, bucket_size, max_position):
sign = np.sign(relative_pos)
sign = torch.sign(relative_pos)
mid = bucket_size // 2
abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
abs_pos = torch.where(
(relative_pos < mid) & (relative_pos > -mid),
torch.tensor(mid - 1).type_as(relative_pos),
torch.abs(relative_pos),
)
log_pos = (
torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
)
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
return bucket_pos
@ -221,12 +227,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
"""
q_ids = np.arange(0, query_size)
k_ids = np.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
q_ids = torch.arange(0, query_size)
k_ids = torch.arange(0, key_size)
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
if bucket_size > 0 and max_position > 0:
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
rel_pos_ids = rel_pos_ids.to(torch.long)
rel_pos_ids = rel_pos_ids[:query_size, :]
rel_pos_ids = rel_pos_ids.unsqueeze(0)
return rel_pos_ids
@ -784,7 +790,7 @@ class DisentangledSelfAttention(nn.Module):
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = math.sqrt(query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
@ -859,7 +865,7 @@ class DisentangledSelfAttention(nn.Module):
score = 0
# content->position
if "c2p" in self.pos_att_type:
scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
c2p_att = torch.gather(
@ -871,7 +877,7 @@ class DisentangledSelfAttention(nn.Module):
# position->content
if "p2c" in self.pos_att_type:
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
if key_layer.size(-2) != query_layer.size(-2):
r_pos = build_relative_position(
key_layer.size(-2),