Core: Fix copies on main (#29624)

fix fix copies
This commit is contained in:
Younes Belkada 2024-03-13 09:16:59 +01:00 committed by GitHub
parent be3fd8a262
commit 9acce7de1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 1 additions and 1 deletions

View File

@ -454,7 +454,7 @@ class GPTJFlashAttention2(GPTJAttention):
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)