938 lines
40 KiB
Python
Executable File
938 lines
40 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
# Copyright 2021 The HuggingFace Team All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
|
text file or a dataset.
|
|
|
|
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
|
https://huggingface.co/models?filter=fill-mask
|
|
"""
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import sys
|
|
import time
|
|
import warnings
|
|
from dataclasses import asdict, dataclass, field
|
|
from enum import Enum
|
|
from itertools import chain
|
|
|
|
# You can also adapt this script to your own masked language modeling task. Pointers for this are left as comments.
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import flax
|
|
import jax
|
|
import jax.numpy as jnp
|
|
import numpy as np
|
|
import optax
|
|
from datasets import load_dataset
|
|
from flax import jax_utils, traverse_util
|
|
from flax.jax_utils import pad_shard_unpad
|
|
from flax.training import train_state
|
|
from flax.training.common_utils import get_metrics, onehot, shard
|
|
from huggingface_hub import HfApi
|
|
from tqdm import tqdm
|
|
|
|
from transformers import (
|
|
CONFIG_MAPPING,
|
|
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
|
AutoConfig,
|
|
AutoTokenizer,
|
|
FlaxAutoModelForMaskedLM,
|
|
HfArgumentParser,
|
|
PreTrainedTokenizerBase,
|
|
TensorType,
|
|
is_tensorboard_available,
|
|
set_seed,
|
|
)
|
|
from transformers.utils import send_example_telemetry
|
|
|
|
|
|
# Config classes of every architecture registered in FLAX_MODEL_FOR_MASKED_LM_MAPPING.
MODEL_CONFIG_CLASSES = [config_class for config_class in FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys()]
# The `model_type` identifier of each of those config classes; listed in the help text of `ModelArguments.model_type`.
MODEL_TYPES = tuple([config_class.model_type for config_class in MODEL_CONFIG_CLASSES])
|
|
|
|
|
|
@dataclass
class TrainingArguments:
    """
    Training-loop hyper-parameters plus output-directory and Model Hub options used by this script.

    Fields whose default is `None` are annotated `Optional[...]` so the annotation matches the actual default.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    # Defaults to None; annotated Optional so the type reflects that.
    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    # Masked as "<HUB_TOKEN>" by `to_dict` because its name ends in "_token".
    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
    gradient_checkpointing: bool = field(
        default=False,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )

    def __post_init__(self):
        # Expand a leading "~" so paths like "~/runs/mlm" resolve to the user's home directory.
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serialize this instance to a plain dict suitable for JSON.

        `Enum` members (and lists whose first element is an `Enum`) are replaced by their `.value`. Every field whose
        name ends in `_token` is replaced by the placeholder `<FIELD_NAME>` so that secrets are never written out.

        Returns:
            dict: field name -> serializable value.
        """
        d = asdict(self)
        # Only existing keys are reassigned, so updating `d` while iterating over it is safe.
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d
|
|
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    # The help text enumerates the module-level MODEL_TYPES tuple.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # Name of a jax.numpy dtype attribute, e.g. "float32"; presumably resolved with getattr(jnp, ...) by the caller.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized and trained. Choose one of"
                " `[float32, float16, bfloat16]`."
            )
        },
    )
    # Defaults to None, so annotated Optional[str].
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    # Deprecated alias of `token`. Its `bool` annotation is intentionally left unchanged so the command-line flag is
    # parsed exactly as before.
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
|
|
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Raises:
        ValueError: (from `__post_init__`) if no data source is given, or if `train_file` / `validation_file` does not
            end in `.csv`, `.json` or `.txt`.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    # NOTE(review): the two *_ref_file fields are not referenced in the visible part of main(); confirm whether whole
    # word masking is actually wired up.
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
    )
    validation_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    # NOTE(review): main() computes min(data_args.max_seq_length, tokenizer.model_max_length), which raises a TypeError
    # when this is left at None — confirm the "Default to the max input length" promise is handled there.
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated. Default to the max input length of the model."
            )
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    def __post_init__(self):
        # Explicit raises instead of `assert`: assertions are stripped when Python runs with -O, which would silently
        # disable this input validation.
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            if extension not in ["csv", "json", "txt"]:
                raise ValueError("`train_file` should be a csv, a json or a txt file.")
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            if extension not in ["csv", "json", "txt"]:
                raise ValueError("`validation_file` should be a csv, a json or a txt file.")
|
|
|
|
|
|
@flax.struct.dataclass
class FlaxDataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`. When that key is missing, the mask is recomputed from the
        token ids with ``tokenizer.get_special_tokens_mask``.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm_probability: float = 0.15

    def __post_init__(self):
        # Masked-LM needs a mask token to substitute for selected positions.
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
        """
        Pad `examples` into one numpy batch and apply MLM masking.

        Returns the dict-like batch produced by `tokenizer.pad`, with `input_ids` replaced by the masked ids, a new
        `labels` entry added, and `special_tokens_mask` (if present) removed.
        """
        # Handle dict or lists with proper padding and conversion to tensor.
        batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)

        # If special token mask has been preprocessed, pop it from the dict; otherwise mask_tokens derives it.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        batch["input_ids"], batch["labels"] = self.mask_tokens(
            batch["input_ids"], special_tokens_mask=special_tokens_mask
        )
        return batch

    def mask_tokens(
        self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray] = None
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: integer token ids; modified in place and also returned.
            special_tokens_mask: non-zero where a token must never be masked. When None, it is computed row by row
                with ``tokenizer.get_special_tokens_mask(..., already_has_special_tokens=True)``.

        Returns:
            ``(inputs, labels)`` where ``labels`` holds the original id at masked positions and -100 elsewhere.
        """
        labels = inputs.copy()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = np.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            # Previously this path crashed on `None.astype`; derive the mask from the ids instead.
            special_tokens_mask = np.array(
                [
                    self.tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
                    for row in labels.tolist()
                ],
                dtype=bool,
            )
        else:
            special_tokens_mask = special_tokens_mask.astype("bool")

        probability_matrix[special_tokens_mask] = 0.0
        masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% of masked positions that were not replaced by the mask token)
        indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
        indices_random &= masked_indices & ~indices_replaced

        random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
|
|
|
|
|
|
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
    """Generate batches of data for a specified batch size from sample indices.

    Every batch holds exactly `batch_size` indices, except possibly the last one. If the dataset size is not divisible
    by the batch size and `drop_last` is `True`, that last incomplete batch is dropped; otherwise it is returned.

    Args:
        samples_idx: 1-D array of sample indices (e.g. a permutation).
        batch_size: number of indices per batch; must be positive.
        drop_last: whether to drop the trailing incomplete batch.

    Returns:
        With `drop_last=True`, a 2-D array of shape `(len(samples_idx) // batch_size, batch_size)`. With
        `drop_last=False`, a list of 1-D arrays (empty list for empty input).

    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"`batch_size` must be a positive integer, got {batch_size}.")

    num_samples = len(samples_idx)
    if drop_last:
        num_full_batches = num_samples // batch_size
        return samples_idx[: num_full_batches * batch_size].reshape((num_full_batches, batch_size))

    if num_samples == 0:
        # np.split would return a single empty batch here; there is nothing to batch.
        return []
    # Cut at every multiple of batch_size so all batches are full except the remainder. np.array_split with
    # ceil(n / batch_size) sections would instead spread samples evenly (e.g. 10 / 4 -> [4, 3, 3]).
    return np.split(samples_idx, list(range(batch_size, num_samples, batch_size)))
|
|
|
|
|
|
def write_train_metric(summary_writer, train_metrics, train_time, step):
    """Write the training wall time and the buffered per-step training metrics to `summary_writer`.

    `train_metrics` is stacked with `get_metrics`. The i-th of the N values recorded for metric `name` is logged as
    `train_<name>` at step `step - N + 1 + i`, i.e. the buffered values are assumed to belong to the N consecutive
    steps ending at `step`.
    """
    summary_writer.scalar("train_time", train_time, step)

    stacked_metrics = get_metrics(train_metrics)
    for metric_name, history in stacked_metrics.items():
        tag = f"train_{metric_name}"
        first_step = step - len(history) + 1
        for offset, value in enumerate(history):
            summary_writer.scalar(tag, value, first_step + offset)
|
|
|
|
|
|
def write_eval_metric(summary_writer, eval_metrics, step):
    """Log every entry of `eval_metrics` to `summary_writer` as scalar `eval_<name>` at `step`."""
    for name in eval_metrics:
        summary_writer.scalar(f"eval_{name}", eval_metrics[name], step)
|
|
|
|
|
|
def main():
|
|
# See all possible arguments in src/transformers/training_args.py
|
|
# or by passing the --help flag to this script.
|
|
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
|
|
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
|
# If we pass only one argument to the script and it's the path to a json file,
|
|
# let's parse it to get our arguments.
|
|
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
|
else:
|
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
|
|
if model_args.use_auth_token is not None:
|
|
warnings.warn(
|
|
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
|
|
FutureWarning,
|
|
)
|
|
if model_args.token is not None:
|
|
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
|
|
model_args.token = model_args.use_auth_token
|
|
|
|
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
|
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
|
send_example_telemetry("run_mlm", model_args, data_args, framework="flax")
|
|
|
|
if (
|
|
os.path.exists(training_args.output_dir)
|
|
and os.listdir(training_args.output_dir)
|
|
and training_args.do_train
|
|
and not training_args.overwrite_output_dir
|
|
):
|
|
raise ValueError(
|
|
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
|
"Use --overwrite_output_dir to overcome."
|
|
)
|
|
|
|
# Setup logging
|
|
logging.basicConfig(
|
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
level=logging.INFO,
|
|
datefmt="[%X]",
|
|
)
|
|
|
|
# Log on each process the small summary:
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Set the verbosity to info of the Transformers logger (on main process only):
|
|
logger.info(f"Training/evaluation parameters {training_args}")
|
|
|
|
# Set seed before initializing model.
|
|
set_seed(training_args.seed)
|
|
|
|
# Handle the repository creation
|
|
if training_args.push_to_hub:
|
|
# Retrieve of infer repo_name
|
|
repo_name = training_args.hub_model_id
|
|
if repo_name is None:
|
|
repo_name = Path(training_args.output_dir).absolute().name
|
|
# Create repo and retrieve repo_id
|
|
api = HfApi()
|
|
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
|
|
|
|
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
|
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
|
# (the dataset will be downloaded automatically from the datasets Hub).
|
|
#
|
|
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
|
# 'text' is found. You can easily tweak this behavior (see below).
|
|
#
|
|
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
|
# download the dataset.
|
|
if data_args.dataset_name is not None:
|
|
# Downloading and loading a dataset from the hub.
|
|
datasets = load_dataset(
|
|
data_args.dataset_name,
|
|
data_args.dataset_config_name,
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
)
|
|
|
|
if "validation" not in datasets.keys():
|
|
datasets["validation"] = load_dataset(
|
|
data_args.dataset_name,
|
|
data_args.dataset_config_name,
|
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
)
|
|
datasets["train"] = load_dataset(
|
|
data_args.dataset_name,
|
|
data_args.dataset_config_name,
|
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
)
|
|
else:
|
|
data_files = {}
|
|
if data_args.train_file is not None:
|
|
data_files["train"] = data_args.train_file
|
|
extension = data_args.train_file.split(".")[-1]
|
|
if data_args.validation_file is not None:
|
|
data_files["validation"] = data_args.validation_file
|
|
extension = data_args.validation_file.split(".")[-1]
|
|
if extension == "txt":
|
|
extension = "text"
|
|
datasets = load_dataset(
|
|
extension,
|
|
data_files=data_files,
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
)
|
|
|
|
if "validation" not in datasets.keys():
|
|
datasets["validation"] = load_dataset(
|
|
extension,
|
|
data_files=data_files,
|
|
split=f"train[:{data_args.validation_split_percentage}%]",
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
)
|
|
datasets["train"] = load_dataset(
|
|
extension,
|
|
data_files=data_files,
|
|
split=f"train[{data_args.validation_split_percentage}%:]",
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
)
|
|
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
|
# https://huggingface.co/docs/datasets/loading_datasets.
|
|
|
|
# Load pretrained model and tokenizer
|
|
|
|
# Distributed training:
|
|
# The .from_pretrained methods guarantee that only one local process can concurrently
|
|
# download model & vocab.
|
|
if model_args.config_name:
|
|
config = AutoConfig.from_pretrained(
|
|
model_args.config_name,
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
)
|
|
elif model_args.model_name_or_path:
|
|
config = AutoConfig.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
cache_dir=model_args.cache_dir,
|
|
token=model_args.token,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
)
|
|
else:
|
|
config = CONFIG_MAPPING[model_args.model_type]()
|
|
logger.warning("You are instantiating a new config instance from scratch.")
|
|
|
|
if model_args.tokenizer_name:
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
model_args.tokenizer_name,
|
|
cache_dir=model_args.cache_dir,
|
|
use_fast=model_args.use_fast_tokenizer,
|
|
token=model_args.token,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
)
|
|
elif model_args.model_name_or_path:
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
cache_dir=model_args.cache_dir,
|
|
use_fast=model_args.use_fast_tokenizer,
|
|
token=model_args.token,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
|
|
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
|
)
|
|
|
|
# Preprocessing the datasets.
|
|
# First we tokenize all the texts.
|
|
if training_args.do_train:
|
|
column_names = datasets["train"].column_names
|
|
else:
|
|
column_names = datasets["validation"].column_names
|
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
|
|
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
|
|
|
if data_args.line_by_line:
|
|
# When using line_by_line, we just tokenize each nonempty line.
|
|
padding = "max_length" if data_args.pad_to_max_length else False
|
|
|
|
def tokenize_function(examples):
|
|
# Remove empty lines
|
|
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
|
return tokenizer(
|
|
examples,
|
|
return_special_tokens_mask=True,
|
|
padding=padding,
|
|
truncation=True,
|
|
max_length=max_seq_length,
|
|
)
|
|
|
|
tokenized_datasets = datasets.map(
|
|
tokenize_function,
|
|
input_columns=[text_column_name],
|
|
batched=True,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
remove_columns=column_names,
|
|
load_from_cache_file=not data_args.overwrite_cache,
|
|
)
|
|
|
|
else:
|
|
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
|
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
|
# efficient when it receives the `special_tokens_mask`.
|
|
def tokenize_function(examples):
|
|
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
|
|
|
tokenized_datasets = datasets.map(
|
|
tokenize_function,
|
|
batched=True,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
remove_columns=column_names,
|
|
load_from_cache_file=not data_args.overwrite_cache,
|
|
)
|
|
|
|
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
|
# max_seq_length.
|
|
def group_texts(examples):
|
|
# Concatenate all texts.
|
|
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
|
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
|
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
|
# customize this part to your needs.
|
|
if total_length >= max_seq_length:
|
|
total_length = (total_length // max_seq_length) * max_seq_length
|
|
# Split by chunks of max_len.
|
|
result = {
|
|
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
|
for k, t in concatenated_examples.items()
|
|
}
|
|
return result
|
|
|
|
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
|
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
|
# might be slower to preprocess.
|
|
#
|
|
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
|
# https://huggingface.co/docs/datasets/process#map
|
|
tokenized_datasets = tokenized_datasets.map(
|
|
group_texts,
|
|
batched=True,
|
|
num_proc=data_args.preprocessing_num_workers,
|
|
load_from_cache_file=not data_args.overwrite_cache,
|
|
)
|
|
|
|
# Enable tensorboard only on the master node
|
|
has_tensorboard = is_tensorboard_available()
|
|
if has_tensorboard and jax.process_index() == 0:
|
|
try:
|
|
from flax.metrics.tensorboard import SummaryWriter
|
|
|
|
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
|
except ImportError as ie:
|
|
has_tensorboard = False
|
|
logger.warning(
|
|
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Unable to display metrics through TensorBoard because the package is not installed: "
|
|
"Please run pip install tensorboard to enable."
|
|
)
|
|
|
|
# Data collator
|
|
# This one will take care of randomly masking the tokens.
|
|
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
|
|
|
# Initialize our training
|
|
rng = jax.random.PRNGKey(training_args.seed)
|
|
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
|
|
|
if model_args.model_name_or_path:
|
|
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
config=config,
|
|
seed=training_args.seed,
|
|
dtype=getattr(jnp, model_args.dtype),
|
|
token=model_args.token,
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
)
|
|
else:
|
|
model = FlaxAutoModelForMaskedLM.from_config(
|
|
config,
|
|
seed=training_args.seed,
|
|
dtype=getattr(jnp, model_args.dtype),
|
|
trust_remote_code=model_args.trust_remote_code,
|
|
)
|
|
|
|
if training_args.gradient_checkpointing:
|
|
model.enable_gradient_checkpointing()
|
|
|
|
# Store some constant
|
|
num_epochs = int(training_args.num_train_epochs)
|
|
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
|
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
|
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
|
|
|
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
|
|
|
# Create learning rate schedule
|
|
warmup_fn = optax.linear_schedule(
|
|
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
|
)
|
|
decay_fn = optax.linear_schedule(
|
|
init_value=training_args.learning_rate,
|
|
end_value=0,
|
|
transition_steps=num_train_steps - training_args.warmup_steps,
|
|
)
|
|
linear_decay_lr_schedule_fn = optax.join_schedules(
|
|
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
|
)
|
|
|
|
# We use Optax's "masking" functionality to not apply weight decay
|
|
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
|
# mask boolean with the same structure as the parameters.
|
|
# The mask is True for parameters that should be decayed.
|
|
def decay_mask_fn(params):
|
|
flat_params = traverse_util.flatten_dict(params)
|
|
# find out all LayerNorm parameters
|
|
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
|
|
layer_norm_named_params = {
|
|
layer[-2:]
|
|
for layer_norm_name in layer_norm_candidates
|
|
for layer in flat_params.keys()
|
|
if layer_norm_name in "".join(layer).lower()
|
|
}
|
|
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
|
|
return traverse_util.unflatten_dict(flat_mask)
|
|
|
|
# create adam optimizer
|
|
if training_args.adafactor:
|
|
# We use the default parameters here to initialize adafactor,
|
|
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
|
optimizer = optax.adafactor(
|
|
learning_rate=linear_decay_lr_schedule_fn,
|
|
)
|
|
else:
|
|
optimizer = optax.adamw(
|
|
learning_rate=linear_decay_lr_schedule_fn,
|
|
b1=training_args.adam_beta1,
|
|
b2=training_args.adam_beta2,
|
|
eps=training_args.adam_epsilon,
|
|
weight_decay=training_args.weight_decay,
|
|
mask=decay_mask_fn,
|
|
)
|
|
|
|
# Setup train state
|
|
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
|
|
|
# Define gradient update step fn
|
|
def train_step(state, batch, dropout_rng):
|
|
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
|
|
|
def loss_fn(params):
|
|
labels = batch.pop("labels")
|
|
|
|
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
|
|
|
# compute loss, ignore padded input tokens
|
|
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
|
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
|
|
|
# take average
|
|
loss = loss.sum()
|
|
num_labels = label_mask.sum()
|
|
|
|
return loss, num_labels
|
|
|
|
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
|
(loss, num_labels), grad = grad_fn(state.params)
|
|
num_labels = jax.lax.psum(num_labels, "batch")
|
|
|
|
# true loss = total loss / total samples
|
|
loss = jax.lax.psum(loss, "batch")
|
|
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
|
|
|
# true grad = total grad / total samples
|
|
grad = jax.lax.psum(grad, "batch")
|
|
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
|
new_state = state.apply_gradients(grads=grad)
|
|
|
|
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
|
|
|
|
return new_state, metrics, new_dropout_rng
|
|
|
|
# Create parallel version of the train step
|
|
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
|
|
|
# Define eval fn
|
|
def eval_step(params, batch):
|
|
labels = batch.pop("labels")
|
|
|
|
logits = model(**batch, params=params, train=False)[0]
|
|
|
|
# compute loss, ignore padded input tokens
|
|
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
|
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
|
|
|
# compute accuracy
|
|
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
|
|
|
|
# summarize metrics
|
|
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
|
|
metrics = jax.lax.psum(metrics, axis_name="batch")
|
|
|
|
return metrics
|
|
|
|
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
|
|
|
# Replicate the train state on each device
|
|
state = jax_utils.replicate(state)
|
|
|
|
train_time = 0
|
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
|
for epoch in epochs:
|
|
# ======================== Training ================================
|
|
train_start = time.time()
|
|
train_metrics = []
|
|
|
|
# Create sampling rng
|
|
rng, input_rng = jax.random.split(rng)
|
|
|
|
# Generate an epoch by shuffling sampling indices from the train dataset
|
|
num_train_samples = len(tokenized_datasets["train"])
|
|
# Avoid using jax.numpy here in case of TPU training
|
|
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
|
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
|
|
|
# Gather the indexes for creating the batch and do a training step
|
|
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
|
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
|
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
|
|
|
# Model forward
|
|
model_inputs = shard(model_inputs.data)
|
|
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
|
train_metrics.append(train_metric)
|
|
|
|
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
|
|
|
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
|
# Save metrics
|
|
train_metric = jax_utils.unreplicate(train_metric)
|
|
train_time += time.time() - train_start
|
|
if has_tensorboard and jax.process_index() == 0:
|
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
|
|
|
epochs.write(
|
|
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
|
|
f" {train_metric['learning_rate']})"
|
|
)
|
|
|
|
train_metrics = []
|
|
|
|
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
|
# ======================== Evaluating ==============================
|
|
num_eval_samples = len(tokenized_datasets["validation"])
|
|
# Avoid using jax.numpy here in case of TPU training
|
|
eval_samples_idx = np.arange(num_eval_samples)
|
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
|
|
|
eval_metrics = []
|
|
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
|
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
|
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
|
|
|
# Model forward
|
|
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
|
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
|
)
|
|
eval_metrics.append(metrics)
|
|
|
|
# normalize eval metrics
|
|
eval_metrics = get_metrics(eval_metrics)
|
|
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
|
|
eval_normalizer = eval_metrics.pop("normalizer")
|
|
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
|
|
|
# Update progress bar
|
|
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
|
|
|
# Save metrics
|
|
if has_tensorboard and jax.process_index() == 0:
|
|
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
|
|
|
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
|
# save checkpoint after each epoch and push checkpoint to the hub
|
|
if jax.process_index() == 0:
|
|
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
|
model.save_pretrained(training_args.output_dir, params=params)
|
|
tokenizer.save_pretrained(training_args.output_dir)
|
|
if training_args.push_to_hub:
|
|
api.upload_folder(
|
|
commit_message=f"Saving weights and logs of step {cur_step}",
|
|
folder_path=training_args.output_dir,
|
|
repo_id=repo_id,
|
|
repo_type="model",
|
|
token=training_args.hub_token,
|
|
)
|
|
# Eval after training
|
|
if training_args.do_eval:
|
|
num_eval_samples = len(tokenized_datasets["validation"])
|
|
# Avoid using jax.numpy here in case of TPU training
|
|
eval_samples_idx = np.arange(num_eval_samples)
|
|
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
|
|
|
|
eval_metrics = []
|
|
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
|
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
|
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
|
|
|
# Model forward
|
|
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
|
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
|
|
)
|
|
eval_metrics.append(metrics)
|
|
|
|
# normalize eval metrics
|
|
eval_metrics = get_metrics(eval_metrics)
|
|
eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
|
|
eval_normalizer = eval_metrics.pop("normalizer")
|
|
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
|
|
|
try:
|
|
perplexity = math.exp(eval_metrics["loss"])
|
|
except OverflowError:
|
|
perplexity = float("inf")
|
|
eval_metrics["perplexity"] = perplexity
|
|
|
|
if jax.process_index() == 0:
|
|
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
|
|
path = os.path.join(training_args.output_dir, "eval_results.json")
|
|
with open(path, "w") as f:
|
|
json.dump(eval_metrics, f, indent=4, sort_keys=True)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|