Update VITS modeling to enable ONNX export (#28141)
* Update vits modeling for onnx export compatibility * fix style * Update src/transformers/models/vits/modeling_vits.py
This commit is contained in:
parent
cadf93a6fc
commit
7226f3d2b0
|
@ -1022,7 +1022,7 @@ class VitsAttention(nn.Module):
|
|||
|
||||
# Pad along column
|
||||
x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
|
||||
x_flat = x.view([batch_heads, length**2 + length * (length - 1)])
|
||||
x_flat = x.view([batch_heads, length * (2 * length - 1)])
|
||||
|
||||
# Add 0's in the beginning that will skew the elements after reshape
|
||||
x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
|
||||
|
|
Loading…
Reference in New Issue