updating t5 config class

thomwolf 2019-12-16 09:51:42 +01:00
parent 7140363e09
commit 1b8613acb3
2 changed files with 3 additions and 14 deletions

@@ -66,7 +66,7 @@ class T5Config(PretrainedConfig):
     pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
 
     def __init__(self,
-                 vocab_size_or_config_json_file=32128,
+                 vocab_size=32128,
                  n_positions=512,
                  d_model=512,
                  d_kv=64,
@@ -79,7 +79,7 @@ class T5Config(PretrainedConfig):
                  initializer_factor=1.0,
                  **kwargs):
         super(T5Config, self).__init__(**kwargs)
-        self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
+        self.vocab_size = vocab_size
         self.n_positions = n_positions
         self.d_model = d_model
         self.d_kv = d_kv
@@ -91,17 +91,6 @@ class T5Config(PretrainedConfig):
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_factor = initializer_factor
 
-        if isinstance(vocab_size_or_config_json_file, six.string_types):
-            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
-                json_config = json.loads(reader.read())
-            for key, value in json_config.items():
-                self.__dict__[key] = value
-        elif not isinstance(vocab_size_or_config_json_file, int):
-            raise ValueError(
-                "First argument must be either a vocabulary size (int)"
-                "or the path to a pretrained model config file (str)"
-            )
-
     @property
     def max_position_embeddings(self):
         return self.n_positions
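
The upshot of the change above: T5Config now takes the vocabulary size as a
plain vocab_size keyword argument, and the constructor no longer doubles as a
JSON-config loader. A minimal sketch of the call sites, assuming the
transformers API of this era; the from_json_file migration path and the file
name are illustrative assumptions, not part of this commit:

    from transformers import T5Config

    # Before this commit the first argument was overloaded: an int meant
    # a vocabulary size, a str meant a path to a JSON config file.
    # config = T5Config(vocab_size_or_config_json_file=32128)

    # After this commit the vocabulary size is an ordinary keyword.
    config = T5Config(vocab_size=32128, d_model=512, num_heads=8)

    # Loading a config from JSON presumably goes through the inherited
    # PretrainedConfig classmethod rather than the constructor.
    # ("t5_config.json" is a hypothetical path.)
    config_from_file = T5Config.from_json_file("t5_config.json")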

@@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
         decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
 
         config = T5Config(
-            vocab_size_or_config_json_file=self.vocab_size,
+            vocab_size=self.vocab_size,
             n_positions=self.n_positions,
             d_model=self.hidden_size,
             d_ff=self.d_ff,