added support for exporting of t5 to onnx with past_key_values (#10651)

This commit is contained in:
Kiran R 2021-04-23 21:44:20 +05:30 committed by GitHub
parent 50f4539b82
commit 5c00918681
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 1 deletions

View File

@ -423,6 +423,8 @@ class T5Attention(nn.Module):
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
int_seq_length = int(seq_length)
real_seq_length = seq_length
if past_key_value is not None:
@ -489,7 +491,7 @@ class T5Attention(nn.Module):
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -seq_length:, :]
position_bias = position_bias[:, :, -int_seq_length:, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)