Update VITS modeling to enable ONNX export (#28141)

* Update vits modeling for onnx export compatibility

* fix style

* Update src/transformers/models/vits/modeling_vits.py
This commit is contained in:
Ella Charlaix 2024-01-05 17:52:32 +01:00 committed by GitHub
parent cadf93a6fc
commit 7226f3d2b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -1022,7 +1022,7 @@ class VitsAttention(nn.Module):
# Pad along column
x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
x_flat = x.view([batch_heads, length**2 + length * (length - 1)])
x_flat = x.view([batch_heads, length * (2 * length - 1)])
# Add 0's in the beginning that will skew the elements after reshape
x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])