Deberta V2: Fix critical trace warnings to allow ONNX export (#18272)
* Fix critical trace warnings to allow ONNX export
* Force input to `sqrt` to be float type
* Cleanup code
* Remove unused import statement
* Update model sew
* Small refactor

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Use broadcasting instead of repeat
* Implement suggestion

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Match deberta v2 changes in sew_d
* Improve code quality
* Update code quality
* Consistency of small refactor
* Match changes in sew_d

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
This commit is contained in:
parent
5d3f037433
commit
d53dffec6e
|
@ -14,11 +14,9 @@
|
|||
# limitations under the License.
|
||||
""" PyTorch DeBERTa-v2 model."""
|
||||
|
||||
import math
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
@ -552,11 +550,17 @@ class DebertaV2Encoder(nn.Module):
|
|||
|
||||
|
||||
def make_log_bucket_position(relative_pos, bucket_size, max_position):
|
||||
sign = np.sign(relative_pos)
|
||||
sign = torch.sign(relative_pos)
|
||||
mid = bucket_size // 2
|
||||
abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
|
||||
log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
|
||||
bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
|
||||
abs_pos = torch.where(
|
||||
(relative_pos < mid) & (relative_pos > -mid),
|
||||
torch.tensor(mid - 1).type_as(relative_pos),
|
||||
torch.abs(relative_pos),
|
||||
)
|
||||
log_pos = (
|
||||
torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
|
||||
)
|
||||
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
|
||||
return bucket_pos
|
||||
|
||||
|
||||
|
@ -578,12 +582,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
|
|||
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
|
||||
|
||||
"""
|
||||
q_ids = np.arange(0, query_size)
|
||||
k_ids = np.arange(0, key_size)
|
||||
rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
|
||||
q_ids = torch.arange(0, query_size)
|
||||
k_ids = torch.arange(0, key_size)
|
||||
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
|
||||
if bucket_size > 0 and max_position > 0:
|
||||
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
|
||||
rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
|
||||
rel_pos_ids = rel_pos_ids.to(torch.long)
|
||||
rel_pos_ids = rel_pos_ids[:query_size, :]
|
||||
rel_pos_ids = rel_pos_ids.unsqueeze(0)
|
||||
return rel_pos_ids
|
||||
|
@ -712,7 +716,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
scale_factor += 1
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
scale = math.sqrt(query_layer.size(-1) * scale_factor)
|
||||
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
|
||||
if self.relative_attention:
|
||||
rel_embeddings = self.pos_dropout(rel_embeddings)
|
||||
|
@ -787,7 +791,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
score = 0
|
||||
# content->position
|
||||
if "c2p" in self.pos_att_type:
|
||||
scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
|
||||
scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
|
||||
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
|
||||
c2p_att = torch.gather(
|
||||
|
@ -799,7 +803,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
|
||||
# position->content
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
|
||||
scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
if key_layer.size(-2) != query_layer.size(-2):
|
||||
r_pos = build_relative_position(
|
||||
key_layer.size(-2),
|
||||
|
|
|
@ -194,11 +194,17 @@ def _compute_mask_indices(
|
|||
|
||||
# Copied from transformers.models.deberta_v2.modeling_deberta_v2.make_log_bucket_position
|
||||
def make_log_bucket_position(relative_pos, bucket_size, max_position):
|
||||
sign = np.sign(relative_pos)
|
||||
sign = torch.sign(relative_pos)
|
||||
mid = bucket_size // 2
|
||||
abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
|
||||
log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
|
||||
bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
|
||||
abs_pos = torch.where(
|
||||
(relative_pos < mid) & (relative_pos > -mid),
|
||||
torch.tensor(mid - 1).type_as(relative_pos),
|
||||
torch.abs(relative_pos),
|
||||
)
|
||||
log_pos = (
|
||||
torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
|
||||
)
|
||||
bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
|
||||
return bucket_pos
|
||||
|
||||
|
||||
|
@ -221,12 +227,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
|
|||
`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
|
||||
|
||||
"""
|
||||
q_ids = np.arange(0, query_size)
|
||||
k_ids = np.arange(0, key_size)
|
||||
rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
|
||||
q_ids = torch.arange(0, query_size)
|
||||
k_ids = torch.arange(0, key_size)
|
||||
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
|
||||
if bucket_size > 0 and max_position > 0:
|
||||
rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
|
||||
rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
|
||||
rel_pos_ids = rel_pos_ids.to(torch.long)
|
||||
rel_pos_ids = rel_pos_ids[:query_size, :]
|
||||
rel_pos_ids = rel_pos_ids.unsqueeze(0)
|
||||
return rel_pos_ids
|
||||
|
@ -784,7 +790,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
scale_factor += 1
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale_factor += 1
|
||||
scale = math.sqrt(query_layer.size(-1) * scale_factor)
|
||||
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
|
||||
if self.relative_attention:
|
||||
rel_embeddings = self.pos_dropout(rel_embeddings)
|
||||
|
@ -859,7 +865,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
score = 0
|
||||
# content->position
|
||||
if "c2p" in self.pos_att_type:
|
||||
scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
|
||||
scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
|
||||
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
|
||||
c2p_att = torch.gather(
|
||||
|
@ -871,7 +877,7 @@ class DisentangledSelfAttention(nn.Module):
|
|||
|
||||
# position->content
|
||||
if "p2c" in self.pos_att_type:
|
||||
scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
|
||||
scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
|
||||
if key_layer.size(-2) != query_layer.size(-2):
|
||||
r_pos = build_relative_position(
|
||||
key_layer.size(-2),
|
||||
|
|
Loading…
Reference in New Issue