added type hints (#19015)

This commit is contained in:
Partho 2022-09-14 17:28:05 +05:30 committed by GitHub
parent fc21c9be62
commit 77ea35b93a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 18 additions and 18 deletions

View File

@ -468,12 +468,12 @@ class FSMTEncoder(nn.Module):
def forward(
self,
input_ids,
attention_mask=None,
head_mask=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
"""
Args:
@ -669,18 +669,18 @@ class FSMTDecoder(nn.Module):
def forward(
self,
input_ids,
encoder_hidden_states,
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
input_ids: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_padding_mask: torch.Tensor,
decoder_padding_mask: torch.Tensor,
decoder_causal_mask: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
):
"""
Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al.,