updating t5 config class
This commit is contained in:
parent
7140363e09
commit
1b8613acb3
|
@ -66,7 +66,7 @@ class T5Config(PretrainedConfig):
|
|||
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
|
||||
def __init__(self,
|
||||
vocab_size_or_config_json_file=32128,
|
||||
vocab_size=32128,
|
||||
n_positions=512,
|
||||
d_model=512,
|
||||
d_kv=64,
|
||||
|
@ -79,7 +79,7 @@ class T5Config(PretrainedConfig):
|
|||
initializer_factor=1.0,
|
||||
**kwargs):
|
||||
super(T5Config, self).__init__(**kwargs)
|
||||
self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, int) else -1
|
||||
self.vocab_size = vocab_size
|
||||
self.n_positions = n_positions
|
||||
self.d_model = d_model
|
||||
self.d_kv = d_kv
|
||||
|
@ -91,17 +91,6 @@ class T5Config(PretrainedConfig):
|
|||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_factor = initializer_factor
|
||||
|
||||
if isinstance(vocab_size_or_config_json_file, six.string_types):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
self.__dict__[key] = value
|
||||
elif not isinstance(vocab_size_or_config_json_file, int):
|
||||
raise ValueError(
|
||||
"First argument must be either a vocabulary size (int)"
|
||||
"or the path to a pretrained model config file (str)"
|
||||
)
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
return self.n_positions
|
||||
|
|
|
@ -93,7 +93,7 @@ class T5ModelTest(CommonTestCases.CommonModelTester):
|
|||
decoder_lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
|
||||
|
||||
config = T5Config(
|
||||
vocab_size_or_config_json_file=self.vocab_size,
|
||||
vocab_size=self.vocab_size,
|
||||
n_positions=self.n_positions,
|
||||
d_model=self.hidden_size,
|
||||
d_ff=self.d_ff,
|
||||
|
|
Loading…
Reference in New Issue