Fixed target_mapping preparation for XLNet when batch size > 1 (incl. beam search) (#7267)
This commit is contained in:
parent
4b3e55bdcc
commit
39062d05f0
|
@ -1313,7 +1313,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||
target_mapping = torch.zeros(
|
||||
(effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device
|
||||
)
|
||||
target_mapping[0, 0, -1] = 1.0
|
||||
target_mapping[:, 0, -1] = 1.0
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
|
|
Loading…
Reference in New Issue