Finalize lm examples (#8188)

* Finish the cleanup of the language-modeling examples

* Update main README

* Apply suggestions from code review

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Apply suggestions from code review

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>

* Propagate changes

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Sylvain Gugger 2020-10-30 14:20:18 -04:00 committed by GitHub
parent 089cc1015e
commit cdc48ce92d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 461 additions and 93 deletions

View File

@ -5,7 +5,9 @@ Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
Here is the list of all our examples:
- **grouped by task** (all official examples work for multiple models)
- with information on whether they are **built on top of `Trainer`/`TFTrainer`** (if not, they still work, they might just lack some features),
- with information on whether they are **built on top of `Trainer`/`TFTrainer`** (if not, they still work, they might
just lack some features),
- whether or not they leverage the [🤗 Datasets](https://github.com/huggingface/datasets) library.
- links to **Colab notebooks** to walk through the scripts and run them easily,
- links to **Cloud deployments** to be able to deploy large-scale trainings in the Cloud with little to no setup.
@ -31,19 +33,19 @@ git checkout tags/v3.4.0
## The Big Table of Tasks
| Task | Example datasets | Trainer support | TFTrainer support | Colab
|---|---|:---:|:---:|:---:|
| [**`language-modeling`**](https://github.com/huggingface/transformers/tree/master/examples/language-modeling) | Raw text | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)
| [**`text-classification`**](https://github.com/huggingface/transformers/tree/master/examples/text-classification) | GLUE, XNLI | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/trainer/01_text_classification.ipynb)
| [**`token-classification`**](https://github.com/huggingface/transformers/tree/master/examples/token-classification) | CoNLL NER | ✅ | ✅ | -
| [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb)
| [**`question-answering`**](https://github.com/huggingface/transformers/tree/master/examples/question-answering) | SQuAD | ✅ | ✅ | -
| [**`text-generation`**](https://github.com/huggingface/transformers/tree/master/examples/text-generation) | - | n/a | n/a | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)
| [**`distillation`**](https://github.com/huggingface/transformers/tree/master/examples/distillation) | All | - | - | -
| [**`summarization`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | CNN/Daily Mail | ✅ | - | -
| [**`translation`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | WMT | ✅ | - | -
| [**`bertology`**](https://github.com/huggingface/transformers/tree/master/examples/bertology) | - | - | - | -
| [**`adversarial`**](https://github.com/huggingface/transformers/tree/master/examples/adversarial) | HANS | ✅ | - | -
| Task | Example datasets | Trainer support | TFTrainer support | 🤗 Datasets | Colab
|---|---|:---:|:---:|:---:|:---:|
| [**`language-modeling`**](https://github.com/huggingface/transformers/tree/master/examples/language-modeling) | Raw text | ✅ | - | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/01_how_to_train.ipynb)
| [**`text-classification`**](https://github.com/huggingface/transformers/tree/master/examples/text-classification) | GLUE, XNLI | ✅ | ✅ | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/huggingface/notebooks/blob/master/examples/text_classification.ipynb)
| [**`token-classification`**](https://github.com/huggingface/transformers/tree/master/examples/token-classification) | CoNLL NER | ✅ | ✅ | - | -
| [**`multiple-choice`**](https://github.com/huggingface/transformers/tree/master/examples/multiple-choice) | SWAG, RACE, ARC | ✅ | ✅ | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ViktorAlm/notebooks/blob/master/MPC_GPU_Demo_for_TF_and_PT.ipynb)
| [**`question-answering`**](https://github.com/huggingface/transformers/tree/master/examples/question-answering) | SQuAD | ✅ | ✅ | - | -
| [**`text-generation`**](https://github.com/huggingface/transformers/tree/master/examples/text-generation) | - | n/a | n/a | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/master/notebooks/02_how_to_generate.ipynb)
| [**`distillation`**](https://github.com/huggingface/transformers/tree/master/examples/distillation) | All | - | - | - | -
| [**`summarization`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | CNN/Daily Mail | ✅ | - | - | -
| [**`translation`**](https://github.com/huggingface/transformers/tree/master/examples/seq2seq) | WMT | ✅ | - | - | -
| [**`bertology`**](https://github.com/huggingface/transformers/tree/master/examples/bertology) | - | - | - | - | -
| [**`adversarial`**](https://github.com/huggingface/transformers/tree/master/examples/adversarial) | HANS | ✅ | - | - | -
<br>

View File

@ -1,16 +1,19 @@
## Language model training
Based on the script [`run_language_modeling.py`](https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_language_modeling.py).
Fine-tuning (or training from scratch) the library models for language modeling on a text dataset for GPT, GPT-2,
ALBERT, BERT, DistilBERT, RoBERTa, XLNet... GPT and GPT-2 are trained or fine-tuned using a causal language modeling
(CLM) loss while ALBERT, BERT, DistilBERT and RoBERTa are trained or fine-tuned using a masked language modeling (MLM)
loss. XLNet uses permutation language modeling (PLM), you can find more information about the differences between those
objectives in our [model summary](https://huggingface.co/transformers/model_summary.html).
Fine-tuning (or training from scratch) the library models for language modeling on a text dataset for GPT, GPT-2, BERT, DistilBERT and RoBERTa. GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT, DistilBERT and RoBERTa
are fine-tuned using a masked language modeling (MLM) loss.
These scripts leverage the 🤗 Datasets library and the Trainer API. You can easily customize them to your needs if you
need extra processing on your datasets.
Before running the following example, you should get a file that contains text on which the language model will be
trained or fine-tuned. A good example of such text is the [WikiText-2 dataset](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/).
**Note:** The old script `run_language_modeling.py` is still available
[here](https://github.com/huggingface/transformers/blob/master/examples/contrib/legacy/language-modeling/run_language_modeling.py).
We will refer to two different files: `$TRAIN_FILE`, which contains text for training, and `$TEST_FILE`, which contains
text that will be used for evaluation.
The following examples, will run on a datasets hosted on our [hub](https://huggingface.co/datasets) or with your own
text files for training and validation. We give examples of both below.
### GPT-2/GPT and causal language modeling
@ -18,66 +21,99 @@ The following example fine-tunes GPT-2 on WikiText-2. We're using the raw WikiTe
the tokenization). The loss here is that of causal language modeling.
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
python run_language_modeling.py \
--output_dir=output \
--model_type=gpt2 \
--model_name_or_path=gpt2 \
python run_clm.py \
--model_name_or_path gpt2 \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE
--output_dir /tmp/test-clm
```
This takes about half an hour to train on a single K80 GPU and about one minute for the evaluation to run. It reaches
a score of ~20 perplexity once fine-tuned on the dataset.
To run on your own training and validation files, use the following command:
```bash
python run_clm.py \
--model_name_or_path gpt2 \
--train_file path_to_train_file \
--validation_file path_to_validation_file \
--do_train \
--do_eval \
--output_dir /tmp/test-clm
```
### RoBERTa/BERT/DistilBERT and masked language modeling
The following example fine-tunes RoBERTa on WikiText-2. Here too, we're using the raw WikiText-2. The loss is different
as BERT/RoBERTa have a bidirectional mechanism; we're therefore using the same loss that was used during their
pre-training: masked language modeling.
In accordance to the RoBERTa paper, we use dynamic masking rather than static masking. The model may, therefore, converge
slightly slower (over-fitting takes more epochs).
We use the `--mlm` flag so that the script may change its loss function.
If using whole-word masking, use both the`--mlm` and `--wwm` flags.
In accordance to the RoBERTa paper, we use dynamic masking rather than static masking. The model may, therefore,
converge slightly slower (over-fitting takes more epochs).
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
python run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
python run_mlm.py \
--model_name_or_path roberta-base \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm \
--whole_word_mask
--output_dir /tmp/test-mlm
```
For Chinese models, it's same with English model with only `--mlm`. If using whole-word masking, we need to generate a reference files, because it's char level.
To run on your own training and validation files, use the following command:
**Q :** Why ref file ?
```bash
python run_clm.py \
--model_name_or_path roberta-base \
--train_file path_to_train_file \
--validation_file path_to_validation_file \
--do_train \
--do_eval \
--output_dir /tmp/test-clm
```
**A :** Suppose we have a Chinese sentence like : `我喜欢你` The original Chinese-BERT will tokenize it as `['我','喜','欢','你']` in char level.
Actually, `喜欢` is a whole word. For whole word mask proxy, We need res like `['我','喜','##欢','你']`.
So we need a ref file to tell model which pos of BERT original token should be added `##`.
### Whole word masking
The BERT authors released a new version of BERT using Whole Word Masking in May 2019. Instead of masking randomly
selected tokens (which may be aprt of words), they mask randomly selected words (masking all the tokens corresponding
to that word). This technique has been refined for Chinese in [this paper](https://arxiv.org/abs/1906.08101).
To fine-tune a model using whole word masking, use the following script:
python run_mlm_wwm.py \
--model_name_or_path roberta-base \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--do_eval \
--output_dir /tmp/test-mlm-wwm
```
For Chinese models, we need to generate a reference files (which requires the ltp library), because it's tokenized at
the character level.
**Q :** Why a reference file?
**A :** Suppose we have a Chinese sentence like: `我喜欢你` The original Chinese-BERT will tokenize it as
`['我','喜','欢','你']` (character level). But `喜欢` is a whole word. For whole word masking proxy, we need a result
like `['我','喜','##欢','你']`, so we need a reference file to tell the model which position of the BERT original token
should be added `##`.
**Q :** Why LTP ?
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT. It works well on so many Chines Task like CLUE (Chinese GLUE).
They use LTP, so if we want to fine-tune their model, we need LTP.
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT.
It works well on so many Chines Task like CLUE (Chinese GLUE). They use LTP, so if we want to fine-tune their model,
we need LTP.
Now LTP only only works well on `transformers==3.2.0`. So we don't add it to requirements.txt.
You need to check to `3.2.0` for `run_chinese_ref.py`. And the code could be found in `examples/contrib`.
You need to create a separate enviromnent with this version of Transformers to run the `run_chinese_ref.py` script that
will create the reference files. The script is in `examples/contrib`. Once in the proper enviromnent, run the
following:
```bash
@ -87,31 +123,25 @@ export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt
python examples/contrib/run_chinese_ref.py \
--file_name=$TRAIN_FILE \
--ltp=$LTP_RESOURCE \
--bert=$BERT_RESOURCE \
--save_path=$SAVE_PATH
--file_name=path_to_train_or_eval_file \
--ltp=path_to_ltp_tokenizer \
--bert=path_to_bert_tokenizer \
--save_path=path_to_reference_file
```
Now Chinese Ref is only supported by `LineByLineWithRefDataset` Class, so we need add `line_by_line` flag:
Then you can run the script like this:
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
export REF_FILE=/path/to/ref.txt
python run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
python run_mlm_wwm.py \
--model_name_or_path roberta-base \
--train_file path_to_train_file \
--validation_file path_to_validation_file \
--train_ref_file path_to_train_chinese_ref_file \
--validation_ref_file path_to_validation_chinese_ref_file \
--do_train \
--train_data_file=$TRAIN_FILE \
--chinese_ref_file=$REF_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm \
--line_by_line \
--whole_word_mask
--output_dir /tmp/test-mlm-wwm
```
### XLNet and permutation language modeling
@ -126,15 +156,26 @@ context length for permutation language modeling.
The `--max_span_length` flag may also be used to limit the length of a span of masked tokens used
for permutation language modeling.
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
Here is how to fine-tun XLNet on wikitext-2:
python run_language_modeling.py \
--output_dir=output \
```bash
python run_plm.py \
--model_name_or_path=xlnet-base-cased \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--output_dir /tmp/test-plm
```
To fine-tune it on your own training and validation file, run:
```bash
python run_plm.py \
--model_name_or_path=xlnet-base-cased \
--train_file path_to_train_file \
--validation_file path_to_validation_file \
--do_train \
--do_eval \
--output_dir /tmp/test-plm
```

View File

@ -175,10 +175,10 @@ def main():
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
# behavior (see below)
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.

View File

@ -0,0 +1,325 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=masked-lm
"""
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
import json
import logging
import math
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from datasets import Dataset, load_dataset
import transformers
from transformers import (
CONFIG_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoModelForMaskedLM,
AutoTokenizer,
DataCollatorForWholeWordMask,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.trainer_utils import is_main_process
logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
},
)
model_type: Optional[str] = field(
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
train_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
)
validation_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated. Default to the max input length of the model."
},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
mlm_probability: float = field(
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
)
def __post_init__(self):
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
def add_chinese_references(dataset, ref_file):
with open(ref_file, "r", encoding="utf-8") as f:
refs = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
assert len(dataset) == len(refs)
dataset_dict = {c: dataset[c] for c in dataset.column_names}
dataset_dict["chinese_ref"] = refs
return Dataset.from_dict(dataset_dict)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty."
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
)
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
# Set seed before initializing model.
set_seed(training_args.seed)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(extension, data_files=data_files)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.html.
# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
if model_args.config_name:
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
elif model_args.model_name_or_path:
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
if model_args.model_name_or_path:
model = AutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
config=config,
cache_dir=model_args.cache_dir,
)
else:
logger.info("Training new model from scratch")
model = AutoModelForMaskedLM.from_config(config)
model.resize_token_embeddings(len(tokenizer))
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
def tokenize_function(examples):
# Remove empty lines
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
return tokenizer(examples["text"], truncation=True, max_length=data_args.max_seq_length)
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)
# Add the chinese references if provided
if data_args.train_ref_file is not None:
tokenized_datasets["train"] = add_chinese_references(tokenized_datasets["train"], data_args.train_ref_file)
if data_args.valid_ref_file is not None:
tokenized_datasets["validation"] = add_chinese_references(
tokenized_datasets["validation"], data_args.validation_ref_file
)
# Data collator
# This one will take care of randomly masking the tokens.
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
tokenizer=tokenizer,
data_collator=data_collator,
)
# Training
if training_args.do_train:
trainer.train(
model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
)
trainer.save_model() # Saves the tokenizer too for easy upload
# Evaluation
results = {}
if training_args.do_eval:
logger.info("*** Evaluate ***")
eval_output = trainer.evaluate()
perplexity = math.exp(eval_output["eval_loss"])
results["perplexity"] = perplexity
output_eval_file = os.path.join(training_args.output_dir, "eval_results_mlm_wwm.txt")
if trainer.is_world_process_zero():
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results *****")
for key, value in results.items():
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")
return results
def _mp_fn(index):
# For xla_spawn (TPUs)
main()
if __name__ == "__main__":
main()

View File

@ -96,7 +96,7 @@ class DataTrainingArguments:
default=None,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated."
"than this will be truncated. Default to the max input length of the model."
},
)
preprocessing_num_workers: Optional[int] = field(
@ -172,10 +172,10 @@ def main():
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
# behavior (see below)
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.

View File

@ -177,7 +177,7 @@ def main():
set_seed(training_args.seed)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
# sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named

View File

@ -190,10 +190,10 @@ def main():
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
# behavior (see below)
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.