commit 9acce7de1c (parent be3fd8a262)
@@ -454,7 +454,7 @@ class GPTJFlashAttention2(GPTJAttention):
             attention_mask (`torch.Tensor`):
                 The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                 position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
+            dropout (`float`):
                 Attention dropout
             softmax_scale (`float`, *optional*):
                 The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
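For context, a minimal sketch of how the documented values are typically constructed. The shapes and example numbers here are assumptions for illustration only, not code from this file; the mask layout and the default scale follow the conventions stated in the docstring above.

```python
import math
import torch

# Hypothetical shapes, chosen only for illustration.
batch_size, seq_len, head_dim = 2, 5, 8

# Padding mask as documented: shape (batch_size, seq_len), where 1 marks a
# real (non-padding) token and 0 marks a padding position. Here the second
# sequence has two padded positions at the end.
attention_mask = torch.tensor(
    [[1, 1, 1, 1, 1],
     [1, 1, 1, 0, 0]]
)

# Attention dropout is a probability, so a float, not an int - which is
# what this commit corrects in the docstring.
dropout = 0.1

# Default softmax scale when none is supplied: 1 / sqrt(head_dim),
# applied to QK^T before the softmax.
softmax_scale = 1 / math.sqrt(head_dim)
```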