Update: ElectraDiscriminatorPredictions forward. (#5471)
`ElectraDiscriminatorPredictions.forward` should not need `attention_mask`.
parent 13a8588f2d
commit ef0e9d806c
@@ -133,7 +133,7 @@ class ElectraDiscriminatorPredictions(nn.Module):
         self.dense_prediction = nn.Linear(config.hidden_size, 1)
         self.config = config
 
-    def forward(self, discriminator_hidden_states, attention_mask):
+    def forward(self, discriminator_hidden_states):
         hidden_states = self.dense(discriminator_hidden_states)
         hidden_states = get_activation(self.config.hidden_act)(hidden_states)
         logits = self.dense_prediction(hidden_states).squeeze()
@@ -518,7 +518,7 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
         )
         discriminator_sequence_output = discriminator_hidden_states[0]
 
-        logits = self.discriminator_predictions(discriminator_sequence_output, attention_mask)
+        logits = self.discriminator_predictions(discriminator_sequence_output)
 
         output = (logits,)
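As the unchanged body lines in the first hunk show, `attention_mask` was never referenced inside the head, so dropping the parameter is a pure signature cleanup: `ElectraForPreTraining` still accepts `attention_mask` for the encoder, while the prediction head now receives only the hidden states. Below is a minimal usage sketch under that assumption; the checkpoint name and tokenizer calls are illustrative and not part of this diff.

```python
# Minimal sketch, assuming a transformers version that includes this change.
# The checkpoint name below is an illustrative assumption, not part of the diff.
import torch
from transformers import ElectraTokenizer, ElectraForPreTraining

tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
model = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")

# attention_mask still goes to the model, where it masks the encoder's self-attention;
# the discriminator prediction head itself no longer takes it.
with torch.no_grad():
    outputs = model(**inputs)

# One logit per token: positive values suggest the token was replaced by the generator.
logits = outputs[0]
print(logits.shape)  # (sequence_length,) for a single example, after the squeeze()
```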