Fixed target_mapping preparation for XLNet when batch size > 1 (incl. beam search) (#7267)

This commit is contained in:
guillaume-be 2020-09-21 10:53:52 +02:00 committed by GitHub
parent 4b3e55bdcc
commit 39062d05f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -1313,7 +1313,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
target_mapping = torch.zeros(
(effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device
)
target_mapping[0, 0, -1] = 1.0
target_mapping[:, 0, -1] = 1.0
inputs = {
"input_ids": input_ids,