Add ByT5 option to example run_t5_mlm_flax.py (#12634)
* Allow ByT5 type in Flax T5 script * use T5TokenizerFast * change up tokenizer config * model_args * reorder imports * Update run_t5_mlm_flax.py
This commit is contained in:
parent
9da1acaea2
commit
5803a2a7ac
|
@ -42,12 +42,12 @@ from flax.training.common_utils import get_metrics, onehot, shard
|
|||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
FlaxT5ForConditionalGeneration,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
T5Config,
|
||||
T5TokenizerFast,
|
||||
TrainingArguments,
|
||||
is_tensorboard_available,
|
||||
set_seed,
|
||||
|
@ -477,11 +477,11 @@ if __name__ == "__main__":
|
|||
# Load pretrained model and tokenizer
|
||||
|
||||
if model_args.tokenizer_name:
|
||||
tokenizer = T5TokenizerFast.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
elif model_args.model_name_or_path:
|
||||
tokenizer = T5TokenizerFast.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
||||
)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue