fix mt5 config (#8832)
This commit is contained in:
parent
18c32eeb21
commit
36b60ce9e8
|
@ -60,6 +60,8 @@ class MT5Config(PretrainedConfig):
|
|||
testing).
|
||||
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`):
|
||||
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
|
||||
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
"""
|
||||
model_type = "mt5"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
@ -79,6 +81,7 @@ class MT5Config(PretrainedConfig):
|
|||
initializer_factor=1.0,
|
||||
feed_forward_proj="gated-gelu",
|
||||
is_encoder_decoder=True,
|
||||
use_cache=True,
|
||||
tokenizer_class="T5Tokenizer",
|
||||
tie_word_embeddings=False,
|
||||
pad_token_id=0,
|
||||
|
@ -109,6 +112,7 @@ class MT5Config(PretrainedConfig):
|
|||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_factor = initializer_factor
|
||||
self.feed_forward_proj = feed_forward_proj
|
||||
self.use_cache = use_cache
|
||||
|
||||
@property
|
||||
def hidden_size(self):
|
||||
|
|
Loading…
Reference in New Issue