Fix uninitialized parameter in conformer relative attention. (#18368)
`torch.Tensor` creates an uninitialized tensor (as `torch.empty` does), which leads to nondeterministic behavior, poor initialization, and NaNs if you get an unlucky init. The paper does not specify the initialization for the bias terms, so zero seems like a good choice: no bias initially. `torch.Tensor` is usually populated with zeros anyway, so this fix stays close to the current behavior:

```
>>> torch.Tensor(100, 100).sum()
tensor(0.)
>>> torch.Tensor(100, 100).sum()
tensor(nan)
>>> torch.Tensor(100, 100).sum()
tensor(0.)
```
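As a side note, a minimal sketch (hypothetical shapes, not from the patch) showing that a zero-initialized `nn.Parameter` is still registered and trainable, so starting from "no bias" does not block learning:

```python
import torch
import torch.nn as nn

# Hypothetical dims standing in for (num_heads, head_size).
bias = nn.Parameter(torch.zeros(4, 16))

# Gradients flow through the addition as usual, even from an all-zero start.
q = torch.randn(4, 10, 16)
loss = (q + bias.unsqueeze(1)).pow(2).sum()
loss.backward()
print(bias.grad.shape)  # torch.Size([4, 16]): the bias receives gradients
```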
This commit is contained in:
parent df5e4232f5
commit 68a894a587
```diff
@@ -670,8 +670,8 @@ class Wav2Vec2ConformerSelfAttention(nn.Module):
         self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
         # these two learnable bias are used in matrix c and matrix d
         # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
+        self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size))
+        self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size))

     def forward(
         self,
```
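For context, a rough sketch (illustrative names and shapes, omitting the relative-shift step of the real layer) of how the two biases enter the attention score in the Transformer-XL scheme the code comment cites:

```python
import torch

num_heads, head_size, seq_len = 4, 16, 10          # hypothetical sizes
q = torch.randn(num_heads, seq_len, head_size)     # queries, per head
k = torch.randn(num_heads, seq_len, head_size)     # keys, per head
r = torch.randn(num_heads, seq_len, head_size)     # projected relative position embeddings

pos_bias_u = torch.zeros(num_heads, 1, head_size)  # enters the content term (matrix c)
pos_bias_v = torch.zeros(num_heads, 1, head_size)  # enters the position term (matrix d)

scores_ac = (q + pos_bias_u) @ k.transpose(-2, -1)  # content-based addressing (terms a + c)
scores_bd = (q + pos_bias_v) @ r.transpose(-2, -1)  # position-based addressing (terms b + d)
scores = (scores_ac + scores_bd) / head_size ** 0.5
```

With both biases at zero, `scores_ac` collapses to the plain query-key term, which is exactly the "no bias initially" behavior this commit aims for.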