Add ByT5 option to example run_t5_mlm_flax.py (#12634)

* Allow ByT5 type in Flax T5 script

* use T5TokenizerFast

* change up tokenizer config

* model_args

* reorder imports

* Update run_t5_mlm_flax.py
Nick Doiron authored 2021-07-13 08:39:57 -04:00, committed by GitHub
parent 9da1acaea2
commit 5803a2a7ac
1 changed file with 3 additions and 3 deletions


@@ -42,12 +42,12 @@ from flax.training.common_utils import get_metrics, onehot, shard
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
+    AutoTokenizer,
     BatchEncoding,
     FlaxT5ForConditionalGeneration,
     HfArgumentParser,
     PreTrainedTokenizerBase,
     T5Config,
-    T5TokenizerFast,
     TrainingArguments,
     is_tensorboard_available,
     set_seed,
@@ -477,11 +477,11 @@ if __name__ == "__main__":
     # Load pretrained model and tokenizer
     if model_args.tokenizer_name:
-        tokenizer = T5TokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
         )
     elif model_args.model_name_or_path:
-        tokenizer = T5TokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
         )
     else:
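
The substance of the change is the switch from the hard-coded T5TokenizerFast to AutoTokenizer, which resolves the tokenizer class from each checkpoint's own config; that is what lets the same script train ByT5 as well as T5 models. Below is a minimal sketch of the behavior the patched script relies on; the checkpoint names (google/byt5-small, t5-small) are illustrative examples and are not part of this diff.

from transformers import AutoTokenizer

# AutoTokenizer picks the tokenizer class from the checkpoint's config,
# so one training script can serve both model families.
byt5_tok = AutoTokenizer.from_pretrained("google/byt5-small")  # byte-level ByT5 tokenizer
t5_tok = AutoTokenizer.from_pretrained("t5-small")             # SentencePiece-based T5TokenizerFast

print(type(byt5_tok).__name__)  # expected: ByT5Tokenizer
print(type(t5_tok).__name__)    # expected: T5TokenizerFast

# ByT5 works directly on UTF-8 bytes, so there is no learned subword vocabulary:
# "hello" encodes to one id per byte plus the trailing </s> token.
print(byt5_tok("hello")["input_ids"])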