add BartConfig.force_bos_token_to_be_generated (#6526)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
974bb4af26
commit
1529bf9680
|
@ -95,6 +95,8 @@ BART_CONFIG_ARGS_DOC = r"""
|
|||
for SequenceClassification
|
||||
is_encoder_decoder (:obj:`bool`, optional, defaults to True):
|
||||
True
|
||||
force_bos_token_to_be_generated (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to force BOS token to be generated at step 1 (after ``decoder_start_token_id``), only true for `bart-large-cnn`.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -137,6 +139,7 @@ class BartConfig(PretrainedConfig):
|
|||
normalize_embedding=True,
|
||||
static_position_embeddings=False,
|
||||
add_bias_logits=False,
|
||||
force_bos_token_to_be_generated=False,
|
||||
**common_kwargs
|
||||
):
|
||||
r"""
|
||||
|
@ -195,6 +198,8 @@ class BartConfig(PretrainedConfig):
|
|||
# pos embedding offset
|
||||
self.extra_pos_embeddings = self.pad_token_id + 1
|
||||
|
||||
self.force_bos_token_to_be_generated = force_bos_token_to_be_generated
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
|
|
@ -1073,23 +1073,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
|
|||
}
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
    """Post-process next-token logits during generation.

    - At step 1 (the token right after ``decoder_start_token_id``), force the
      BOS token to be generated, but only when
      ``config.force_bos_token_to_be_generated`` is set (true e.g. for
      `bart-large-cnn`).
    - At the last step (``cur_len == max_length - 1``), force the EOS token
      (if one is configured) so sequences terminate.

    Args:
        logits: next-token scores of shape ``[batch_size, vocab_size]``;
            modified in place by ``_force_token_ids_generation``.
        cur_len: current length of the generated sequence.
        max_length: maximum generation length.

    Returns:
        The (possibly in-place modified) ``logits``.
    """
    if cur_len == 1 and self.config.force_bos_token_to_be_generated:
        self._force_token_ids_generation(logits, self.config.bos_token_id)
    elif cur_len == max_length - 1 and self.config.eos_token_id is not None:
        self._force_token_ids_generation(logits, self.config.eos_token_id)
    return logits
|
||||
|
||||
def _force_token_ids_generation(self, scores, token_id) -> None:
    """Force ``token_id`` to be generated by setting the score of every other
    vocabulary token to ``-inf`` (logprob = -float("inf")), in place.

    Args:
        scores: tensor of shape ``[batch_size, vocab_size]`` holding
            next-token scores; modified in place.
        token_id: the single token id that remains generatable.
    """
    # Index-assign every column except `token_id`; building the index list in
    # Python is O(vocab_size) but runs once per forced step only.
    scores[:, [x for x in range(self.config.vocab_size) if x != token_id]] = -float("inf")
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past, beam_idx):
|
||||
|
|
|
@ -47,7 +47,7 @@ class MarianMTModel(BartForConditionalGeneration):
|
|||
"""
|
||||
|
||||
def adjust_logits_during_generation(self, logits, cur_len, max_length):
    """Post-process next-token logits for Marian generation.

    Marian models must never predict the pad token, so its logit is always
    masked to ``-inf``; additionally, the EOS token is forced at the final
    generation step (``cur_len == max_length - 1``) when one is configured.

    Args:
        logits: next-token scores of shape ``[batch_size, vocab_size]``;
            modified in place.
        cur_len: current length of the generated sequence.
        max_length: maximum generation length.

    Returns:
        The in-place modified ``logits``.
    """
    logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
    if cur_len == max_length - 1 and self.config.eos_token_id is not None:
        self._force_token_ids_generation(logits, self.config.eos_token_id)
    return logits
|
||||
|
|
|
@ -484,7 +484,7 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||
self.assertFalse(model.config.is_valid_mbart())
|
||||
tok = BartTokenizer.from_pretrained("facebook/bart-large")
|
||||
|
||||
EXPECTED_SUMMARY = "California's largest power company has begun shutting off power to tens of thousands of homes and businesses in the state."
|
||||
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
|
||||
dct = tok.batch_encode_plus(
|
||||
[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
|
||||
).to(torch_device)
|
||||
|
|
Loading…
Reference in New Issue