fix mt5 config (#8832)

This commit is contained in:
Patrick von Platen 2020-11-28 19:50:49 +01:00 committed by GitHub
parent 18c32eeb21
commit 36b60ce9e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 0 deletions

View File

@ -60,6 +60,8 @@ class MT5Config(PretrainedConfig):
testing).
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
"""
model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"]
@ -79,6 +81,7 @@ class MT5Config(PretrainedConfig):
initializer_factor=1.0,
feed_forward_proj="gated-gelu",
is_encoder_decoder=True,
use_cache=True,
tokenizer_class="T5Tokenizer",
tie_word_embeddings=False,
pad_token_id=0,
@ -109,6 +112,7 @@ class MT5Config(PretrainedConfig):
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
self.use_cache = use_cache
@property
def hidden_size(self):