enable cache by default (#9296)
This commit is contained in:
parent
6189ae9960
commit
2a18b70998
|
@ -61,6 +61,9 @@ class BertGenerationConfig(PretrainedConfig):
|
|||
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
|
||||
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
|
||||
<https://arxiv.org/abs/2009.13658>`__.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if ``config.is_decoder=True``.
|
||||
|
||||
Examples::
|
||||
|
||||
|
@ -95,6 +98,7 @@ class BertGenerationConfig(PretrainedConfig):
|
|||
eos_token_id=1,
|
||||
gradient_checkpointing=False,
|
||||
position_embedding_type="absolute",
|
||||
use_cache=True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
||||
|
@ -112,3 +116,4 @@ class BertGenerationConfig(PretrainedConfig):
|
|||
self.layer_norm_eps = layer_norm_eps
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.position_embedding_type = position_embedding_type
|
||||
self.use_cache = use_cache
|
||||
|
|
|
@ -339,6 +339,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
|
|||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
|
|
Loading…
Reference in New Issue