Added support for exporting T5 to ONNX with past_key_values (#10651)
This commit is contained in:
parent
50f4539b82
commit
5c00918681
|
@ -423,6 +423,8 @@ class T5Attention(nn.Module):
|
|||
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
|
||||
batch_size, seq_length = hidden_states.shape[:2]
|
||||
|
||||
int_seq_length = int(seq_length)
|
||||
|
||||
real_seq_length = seq_length
|
||||
|
||||
if past_key_value is not None:
|
||||
|
@ -489,7 +491,7 @@ class T5Attention(nn.Module):
|
|||
# if key and values are already calculated
|
||||
# we want only the last query position bias
|
||||
if past_key_value is not None:
|
||||
position_bias = position_bias[:, :, -seq_length:, :]
|
||||
position_bias = position_bias[:, :, -int_seq_length:, :]
|
||||
|
||||
if mask is not None:
|
||||
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
||||
|
|
Loading…
Reference in New Issue