Fix for the issue of device-id getting hardcoded for position-ids during Tracing for Flaubert (#12292)
* adding position_ids buffer to fix the issue similar to #5664 * adding position-id buffer to address similar issues to #5664
This commit is contained in:
parent
58e999b7e6
commit
5d1a3d135c
|
@ -18,6 +18,7 @@
|
|||
import random
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
|
||||
|
@ -140,6 +141,10 @@ class FlaubertModel(XLMModel):
|
|||
super().__init__(config)
|
||||
self.layerdrop = getattr(config, "layerdrop", 0.0)
|
||||
self.pre_norm = getattr(config, "pre_norm", False)
|
||||
if version.parse(torch.__version__) > version.parse("1.6.0"):
|
||||
self.register_buffer(
|
||||
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
||||
)
|
||||
|
||||
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
|
@ -198,10 +203,16 @@ class FlaubertModel(XLMModel):
|
|||
# if self.is_decoder and src_enc is not None:
|
||||
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
|
||||
|
||||
# position_ids
|
||||
# Setting the position-ids to the registered buffer in constructor, it helps
|
||||
# when tracing the model without passing position-ids, solves
|
||||
# issues similar to issue #5664
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(slen, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand((bs, slen))
|
||||
if hasattr(self, "position_ids"):
|
||||
position_ids = self.position_ids[:, :slen]
|
||||
position_ids = position_ids.expand((bs, slen))
|
||||
else:
|
||||
position_ids = torch.arange(slen, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0).expand((bs, slen))
|
||||
else:
|
||||
assert position_ids.size() == (bs, slen) # (slen, bs)
|
||||
# position_ids = position_ids.transpose(0, 1)
|
||||
|
|
Loading…
Reference in New Issue