enable cache by default (#9296)

This commit is contained in:
Suraj Patil 2020-12-24 17:47:36 +05:30 committed by GitHub
parent 6189ae9960
commit 2a18b70998
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 0 deletions

View File

@ -61,6 +61,9 @@ class BertGenerationConfig(PretrainedConfig):
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
Examples::
@ -95,6 +98,7 @@ class BertGenerationConfig(PretrainedConfig):
eos_token_id=1,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@ -112,3 +116,4 @@ class BertGenerationConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache

View File

@ -339,6 +339,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: