Fix for the issue of device-id getting hardcoded for position-ids during Tracing for Flaubert (#12292)

* adding position_ids buffer to fix the issue simialr to #5664

* adding position-id buffer to address similar issues to #5664
This commit is contained in:
Hamid Shojanazeri 2021-09-01 01:46:58 -07:00 committed by GitHub
parent 58e999b7e6
commit 5d1a3d135c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 3 deletions

View File

@ -18,6 +18,7 @@
import random
import torch
from packaging import version
from torch import nn
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@ -140,6 +141,10 @@ class FlaubertModel(XLMModel):
super().__init__(config)
self.layerdrop = getattr(config, "layerdrop", 0.0)
self.pre_norm = getattr(config, "pre_norm", False)
if version.parse(torch.__version__) > version.parse("1.6.0"):
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
@add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
@ -198,10 +203,16 @@ class FlaubertModel(XLMModel):
# if self.is_decoder and src_enc is not None:
# src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
# position_ids
# Setting the position-ids to the registered buffer in constructor, it helps
# when tracing the model without passing position-ids, solves
# isues similar to issue #5664
if position_ids is None:
position_ids = torch.arange(slen, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand((bs, slen))
if hasattr(self, "position_ids"):
position_ids = self.position_ids[:, :slen]
position_ids = position_ids.expand((bs, slen))
else:
position_ids = torch.arange(slen, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand((bs, slen))
else:
assert position_ids.size() == (bs, slen) # (slen, bs)
# position_ids = position_ids.transpose(0, 1)