From 1b8613acb32a568db8d9b74ee182d43c4f8e9cbb Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 16 Dec 2019 09:51:42 +0100 Subject: [PATCH] updating t5 config class --- transformers/configuration_t5.py | 15 ++------------- transformers/tests/modeling_t5_test.py | 2 +- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/transformers/configuration_t5.py b/transformers/configuration_t5.py index 6391cb4180..377a0919d9 100644 --- a/transformers/configuration_t5.py +++ b/transformers/configuration_t5.py @@ -66,7 +66,7 @@ class T5Config(PretrainedConfig): pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP def __init__(self, - vocab_size_or_config_json_file=32128, + vocab_size=32128, n_positions=512, d_model=512, d_kv=64, @@ -79,7 +79,7 @@ class T5Config(PretrainedConfig): initializer_factor=1.0, **kwargs): super(T5Config, self).__init__(**kwargs) - self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1 + self.vocab_size = vocab_size self.n_positions = n_positions self.d_model = d_model self.d_kv = d_kv @@ -91,17 +91,6 @@ class T5Config(PretrainedConfig): self.layer_norm_epsilon = layer_norm_epsilon self.initializer_factor = initializer_factor - if isinstance(vocab_size_or_config_json_file, six.string_types): - with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif not isinstance(vocab_size_or_config_json_file, int): - raise ValueError( - "First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)" - ) - @property def max_position_embeddings(self): return self.n_positions diff --git a/transformers/tests/modeling_t5_test.py b/transformers/tests/modeling_t5_test.py index a539cc868a..c337163375 100644 --- a/transformers/tests/modeling_t5_test.py +++ b/transformers/tests/modeling_t5_test.py @@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester): decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) config = T5Config( - vocab_size_or_config_json_file=self.vocab_size, + vocab_size=self.vocab_size, n_positions=self.n_positions, d_model=self.hidden_size, d_ff=self.d_ff,