Update: ElectraDiscriminatorPredictions forward. (#5471)

`ElectraDiscriminatorPredictions.forward` should not need `attention_mask`.
commit ef0e9d806c
parent 13a8588f2d
Author: Shen
Date: 2020-07-02 12:57:33 -05:00 (committed by GitHub)
1 changed file with 2 additions and 2 deletions

@@ -133,7 +133,7 @@ class ElectraDiscriminatorPredictions(nn.Module):
         self.dense_prediction = nn.Linear(config.hidden_size, 1)
         self.config = config
 
-    def forward(self, discriminator_hidden_states, attention_mask):
+    def forward(self, discriminator_hidden_states):
         hidden_states = self.dense(discriminator_hidden_states)
         hidden_states = get_activation(self.config.hidden_act)(hidden_states)
         logits = self.dense_prediction(hidden_states).squeeze()
@@ -518,7 +518,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
         )
         discriminator_sequence_output = discriminator_hidden_states[0]
 
-        logits = self.discriminator_predictions(discriminator_sequence_output, attention_mask)
+        logits = self.discriminator_predictions(discriminator_sequence_output)
 
         output = (logits,)
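
For context, a minimal usage sketch of the call path after this change. It assumes a transformers version that includes this commit; the google/electra-small-discriminator checkpoint and the example sentence are only illustrative.

    import torch
    from transformers import ElectraForPreTraining, ElectraTokenizer

    tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
    model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")

    with torch.no_grad():
        # attention_mask is still consumed by the base ElectraModel,
        # which is where it actually matters (for masking attention).
        discriminator_hidden_states = model.electra(**inputs)
        sequence_output = discriminator_hidden_states[0]

        # The prediction head now takes only the hidden states,
        # matching the updated forward signature above.
        logits = model.discriminator_predictions(sequence_output)

    print(logits.shape)  # (sequence_length,) for a single example, after .squeeze()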