transformers/examples/tensorflow/language-modeling-tpu/train_unigram.py

120 lines
3.7 KiB
Python

#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script for training a Unigram tokenizer."""
import argparse
import logging
import datasets
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
from tokenizers.models import Unigram
from tokenizers.trainers import UnigramTrainer
from transformers import AlbertTokenizerFast
# Module-level logger used for this script's progress messages.
logger = logging.getLogger(__name__)
def parse_args(argv=None):
    """Parse the command-line options of the Unigram tokenizer training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads the process arguments, so existing
            ``parse_args()`` callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset_name``,
        ``dataset_config``, ``batch_size``, ``vocab_size``, ``limit`` and
        ``export_to_hub``.
    """
    parser = argparse.ArgumentParser(description="Train a unigram tokenizer on the wikitext dataset.")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="wikitext",
        help="Name of the training dataset. Explore datasets at: hf.co/datasets.",
    )
    parser.add_argument(
        "--dataset_config",
        type=str,
        default="wikitext-103-raw-v1",
        help="Configuration name of the dataset.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1000,
        help="Batch size during training.",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=10048,
        help="Size of the desired vocabulary.",
    )
    parser.add_argument(
        "--limit",
        default=None,
        type=int,
        # The value is applied as a cap on the number of dataset rows, not shards.
        help="Limit the number of training samples (used for debugging).",
    )
    parser.add_argument(
        "--export_to_hub",
        action="store_true",
        help="Push the trained tokenizer to the Hugging Face Hub.",
    )
    return parser.parse_args(argv)
def main(args):
    """Train a Unigram tokenizer on a dataset and optionally push it to the Hub.

    Args:
        args: Parsed options. The function reads ``dataset_name``,
            ``dataset_config``, ``limit``, ``batch_size``, ``vocab_size`` and
            ``export_to_hub``.

    Note:
        The trained tokenizer is only persisted when ``args.export_to_hub`` is
        set; otherwise it is discarded when the function returns.
    """
    dataset = datasets.load_dataset(args.dataset_name, args.dataset_config, split="train")

    if args.limit is not None:
        max_train_samples = min(len(dataset), args.limit)
        dataset = dataset.select(range(max_train_samples))
        # Report the number of rows actually kept: it is smaller than the
        # requested limit when the split has fewer rows than ``args.limit``.
        logger.info("Limiting the dataset to %d entries.", len(dataset))

    def batch_iterator():
        # Feed the "text" column to the trainer in slices of ``batch_size`` rows.
        for i in range(0, len(dataset), args.batch_size):
            yield dataset[i : i + args.batch_size]["text"]

    # Prepare the tokenizer.
    tokenizer = Tokenizer(Unigram())
    # Rewrite the `` and '' quote pairs to a plain double quote.
    tokenizer.normalizer = normalizers.Sequence([normalizers.Replace("``", '"'), normalizers.Replace("''", '"')])
    tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()

    # Prepare the trainer.
    trainer = UnigramTrainer(
        unk_token="<unk>",
        special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
        vocab_size=args.vocab_size,
    )

    logger.info("Training the tokenizer.")
    # ``length`` is the total number of sequences the iterator will produce,
    # which lets the trainer report meaningful progress.
    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer, length=len(dataset))
    logger.info("Tokenizer training complete!")

    # The ids of the special tokens are only known once training has built the vocabulary.
    cls_token_id = tokenizer.token_to_id("[CLS]")
    sep_token_id = tokenizer.token_to_id("[SEP]")
    # Wrap sequences as "[CLS] A [SEP]" (type id 0) and, for pairs, append
    # "B [SEP]" with type id 1.
    tokenizer.post_processor = processors.TemplateProcessing(
        single="[CLS]:0 $A:0 [SEP]:0",
        pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
        special_tokens=[
            ("[CLS]", cls_token_id),
            ("[SEP]", sep_token_id),
        ],
    )
    tokenizer.decoder = decoders.Metaspace()

    if args.export_to_hub:
        logger.info("Exporting the trained tokenizer to Hub.")
        new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
        new_tokenizer.push_to_hub("unigram-tokenizer-dataset")
if __name__ == "__main__":
    # Without any logging configuration the INFO-level progress messages of
    # this script are dropped (the root logger defaults to WARNING), so set up
    # a basic handler before doing any work.
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    main(args)