From 5d1a3d135cb23228d68bdf795133ce7ac5e2ef24 Mon Sep 17 00:00:00 2001
From: Hamid Shojanazeri
Date: Wed, 1 Sep 2021 01:46:58 -0700
Subject: [PATCH] Fix for the issue of device-id getting hardcoded for
 position-ids during tracing for Flaubert (#12292)

* adding position_ids buffer to fix an issue similar to #5664

* adding position_ids buffer to address issues similar to #5664
---
 .../models/flaubert/modeling_flaubert.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py
index 5c0826f014..161929db82 100644
--- a/src/transformers/models/flaubert/modeling_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_flaubert.py
@@ -18,6 +18,7 @@
 import random
 
 import torch
+from packaging import version
 from torch import nn
 
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -140,6 +141,10 @@ class FlaubertModel(XLMModel):
         super().__init__(config)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
+        if version.parse(torch.__version__) > version.parse("1.6.0"):
+            self.register_buffer(
+                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+            )
 
     @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -198,10 +203,16 @@ class FlaubertModel(XLMModel):
         # if self.is_decoder and src_enc is not None:
         #     src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
 
-        # position_ids
+        # Use the position_ids buffer registered in the constructor when available;
+        # this keeps the device from being hardcoded when tracing the model without
+        # passing position_ids, which resolves issues similar to #5664
         if position_ids is None:
-            position_ids = torch.arange(slen, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0).expand((bs, slen))
+            if hasattr(self, "position_ids"):
+                position_ids = self.position_ids[:, :slen]
+                position_ids = position_ids.expand((bs, slen))
+            else:
+                position_ids = torch.arange(slen, dtype=torch.long, device=device)
+                position_ids = position_ids.unsqueeze(0).expand((bs, slen))
         else:
             assert position_ids.size() == (bs, slen)  # (slen, bs)
             # position_ids = position_ids.transpose(0, 1)
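
Beyond the diff itself, the pattern being applied may be easier to see in isolation: registering position_ids as a non-persistent buffer (on the PyTorch versions covered by the version guard in the patch) means the tensor moves with the module on .to(device), so torch.jit.trace no longer captures a fixed device the way torch.arange(slen, dtype=torch.long, device=device) does. The minimal sketch below is not the Flaubert code; the module name PositionIdsExample, the vocabulary size, and the dummy input shapes are invented for illustration only.

import torch
from torch import nn


class PositionIdsExample(nn.Module):
    """Toy module showing the registered position_ids buffer pattern."""

    def __init__(self, vocab_size=100, max_position_embeddings=512, dim=8):
        super().__init__()
        self.word_embed = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(max_position_embeddings, dim)
        # The buffer follows the module across .to(device) calls and, being
        # non-persistent, is not written to the state dict.
        self.register_buffer(
            "position_ids", torch.arange(max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(self, input_ids):
        bs, slen = input_ids.shape
        # Slice the pre-built buffer instead of calling torch.arange(..., device=device),
        # so tracing does not bake a device id into the graph.
        position_ids = self.position_ids[:, :slen].expand((bs, slen))
        return self.word_embed(input_ids) + self.pos_embed(position_ids)


model = PositionIdsExample()
dummy_input_ids = torch.zeros((2, 16), dtype=torch.long)
traced = torch.jit.trace(model, dummy_input_ids)
# The traced module can later be moved to another device (e.g. traced.to("cuda")
# on a CUDA machine) without a stale device constant left over from tracing.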