add BartConfig.force_bos_token_to_be_generated (#6526)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Sam Shleifer 2020-08-18 19:15:50 -04:00 committed by GitHub
parent 974bb4af26
commit 1529bf9680
4 changed files with 12 additions and 15 deletions

src/transformers/configuration_bart.py

@@ -95,6 +95,8 @@ BART_CONFIG_ARGS_DOC = r"""
             for SequenceClassification
         is_encoder_decoder (:obj:`int`, optional, defaults to True):
             True
+        force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only true for `bart-large-cnn`.
 """
@@ -137,6 +139,7 @@ class BartConfig(PretrainedConfig):
         normalize_embedding=True,
         static_position_embeddings=False,
         add_bias_logits=False,
+        force_bos_token_to_be_generated=False,
         **common_kwargs
     ):
         r"""
@@ -195,6 +198,8 @@ class BartConfig(PretrainedConfig):
         # pos embedding offset
         self.extra_pos_embeddings = self.pad_token_id + 1
+        self.force_bos_token_to_be_generated = force_bos_token_to_be_generated

     @property
     def num_attention_heads(self) -> int:
         return self.encoder_attention_heads
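For orientation, not part of the diff: a minimal sketch of the new flag from the caller's side, assuming only that `BartConfig` is imported as in transformers of this era. The flag defaults to False and is meant to be switched on only for checkpoints like `bart-large-cnn`.

    from transformers import BartConfig

    # Default: generation will not force BOS at step 1.
    config = BartConfig()
    assert config.force_bos_token_to_be_generated is False

    # Opt in explicitly, mirroring the behavior bart-large-cnn relies on.
    cnn_style = BartConfig(force_bos_token_to_be_generated=True)
    assert cnn_style.force_bos_token_to_be_generated is True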

src/transformers/modeling_bart.py

@@ -1073,23 +1073,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
         }

     def adjust_logits_during_generation(self, logits, cur_len, max_length):
-        if cur_len == 1:
+        if cur_len == 1 and self.config.force_bos_token_to_be_generated:
             self._force_token_ids_generation(logits, self.config.bos_token_id)
-        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
+        elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
             self._force_token_ids_generation(logits, self.config.eos_token_id)
         return logits

-    def _force_token_ids_generation(self, scores, token_ids) -> None:
-        """force one of token_ids to be generated by setting prob of all other tokens to 0"""
-        if isinstance(token_ids, int):
-            token_ids = [token_ids]
-        all_but_token_ids_mask = torch.tensor(
-            [x for x in range(self.config.vocab_size) if x not in token_ids],
-            dtype=torch.long,
-            device=next(self.parameters()).device,
-        )
-        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
-        scores[:, all_but_token_ids_mask] = -float("inf")
+    def _force_token_ids_generation(self, scores, token_id) -> None:
+        """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))"""
+        scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")

     @staticmethod
     def _reorder_cache(past, beam_idx):
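A standalone illustration of what the simplified helper does, with toy tensors rather than real model scores and not part of the diff: masking every vocabulary position except the forced token sends all probability mass to it after softmax.

    import torch

    vocab_size, forced_id = 5, 2
    scores = torch.zeros(2, vocab_size)  # rank 2: [batch_size, vocab_size], as the helper expects
    scores[:, [x for x in range(vocab_size) if x != forced_id]] = -float("inf")

    probs = torch.softmax(scores, dim=-1)
    assert (probs.argmax(dim=-1) == forced_id).all()       # only the forced token can win
    assert torch.allclose(probs[:, forced_id], torch.ones(2))  # it gets probability 1.0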

src/transformers/modeling_marian.py

@@ -47,7 +47,7 @@ class MarianMTModel(BartForConditionalGeneration):
     """

     def adjust_logits_during_generation(self, logits, cur_len, max_length):
-        logits[:, self.config.pad_token_id] = float("-inf")
+        logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
         if cur_len == max_length - 1 and self.config.eos_token_id is not None:
             self._force_token_ids_generation(logits, self.config.eos_token_id)
         return logits
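The Marian override applies the pad mask at every step, then reuses the inherited EOS forcing on the final step. A toy check of the pad mask alone (toy ids and random logits, not part of the diff):

    import torch

    pad_token_id, vocab_size = 0, 5
    logits = torch.randn(3, vocab_size)
    logits[:, pad_token_id] = float("-inf")  # never predict pad token
    assert (logits.argmax(dim=-1) != pad_token_id).all()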

tests/test_modeling_bart.py

@@ -484,7 +484,7 @@ class BartModelIntegrationTests(unittest.TestCase):
         self.assertFalse(model.config.is_valid_mbart())
         tok = BartTokenizer.from_pretrained("facebook/bart-large")

-        EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
+        EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
         dct = tok.batch_encode_plus(
             [PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
         ).to(torch_device)
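An end-to-end sketch of the behavior the integration test pins down. It is not part of the diff and assumes a transformers version contemporary with this commit, network access to the hosted checkpoints, and that the hosted `bart-large-cnn` config has the flag set to True, as the docstring above states.

    import torch
    from transformers import BartForConditionalGeneration, BartTokenizer

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    tok = BartTokenizer.from_pretrained("facebook/bart-large")
    assert model.config.force_bos_token_to_be_generated  # assumed on for this checkpoint

    article = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."
    batch = tok([article], max_length=1024, truncation=True, return_tensors="pt")
    with torch.no_grad():
        summary_ids = model.generate(batch["input_ids"], num_beams=4, max_length=40)
    print(tok.decode(summary_ids[0], skip_special_tokens=True))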