added type hints (#19015)
This commit is contained in:
parent
fc21c9be62
commit
77ea35b93a
|
@ -468,12 +468,12 @@ class FSMTEncoder(nn.Module):
|
|||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -669,18 +669,18 @@ class FSMTDecoder(nn.Module):
|
|||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
encoder_hidden_states,
|
||||
encoder_padding_mask,
|
||||
decoder_padding_mask,
|
||||
decoder_causal_mask,
|
||||
head_mask=None,
|
||||
cross_attn_head_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
input_ids: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_padding_mask: torch.Tensor,
|
||||
decoder_padding_mask: torch.Tensor,
|
||||
decoder_causal_mask: torch.Tensor,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
output_hidden_states: bool = False,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al.,
|
||||
|
|
Loading…
Reference in New Issue