# coding=utf-8
# Copyright 2019-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training the distilled model.
Supported architectures include: BERT -> DistilBERT, RoBERTa -> DistilRoBERTa, GPT2 -> DistilGPT2.
"""

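# Example invocation (illustrative only; the file paths below are placeholders for your own data/config files):
#   python train.py --force --dump_path serialization_dir/my_first_training \
#       --student_type distilbert --student_config training_configs/distilbert-base-uncased.json \
#       --teacher_type bert --teacher_name bert-base-uncased \
#       --mlm --alpha_ce 5.0 --alpha_mlm 2.0 --alpha_clm 0.0 --alpha_cos 1.0 \
#       --data_file data/binarized_text.bert-base-uncased.pickle \
#       --token_counts data/token_counts.bert-base-uncased.pickle
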
import argparse
import json
import os
import pickle
import shutil

import numpy as np
import torch
from distiller import Distiller
from lm_seqs_dataset import LmSeqsDataset

from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    DistilBertConfig,
    DistilBertForMaskedLM,
    DistilBertTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    RobertaConfig,
    RobertaForMaskedLM,
    RobertaTokenizer,
)
from utils import git_log, init_gpu_params, logger, set_seed


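# Map each model shortcut name to its (config class, model-with-LM-head class, tokenizer class).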
MODEL_CLASSES = {
    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
}


def sanity_checks(args):
    """
    A bunch of args sanity checks to perform even before starting...
    """
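    # Exactly one language-modeling objective must be active: MLM (with the `--mlm` flag) or CLM, never both.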
    assert (args.mlm and args.alpha_mlm > 0.0) or (not args.mlm and args.alpha_mlm == 0.0)
    assert (args.alpha_mlm > 0.0 and args.alpha_clm == 0.0) or (args.alpha_mlm == 0.0 and args.alpha_clm > 0.0)
    if args.mlm:
        assert os.path.isfile(args.token_counts)
        assert (args.student_type in ["roberta", "distilbert"]) and (args.teacher_type in ["roberta", "bert"])
    else:
        assert (args.student_type in ["gpt2"]) and (args.teacher_type in ["gpt2"])

    assert args.teacher_type == args.student_type or (
        args.student_type == "distilbert" and args.teacher_type == "bert"
    )
    assert os.path.isfile(args.student_config)
    if args.student_pretrained_weights is not None:
        assert os.path.isfile(args.student_pretrained_weights)

    if args.freeze_token_type_embds:
        assert args.student_type in ["roberta"]

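    # Every loss weight must be non-negative, and at least one of them must be strictly positive.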
    assert args.alpha_ce >= 0.0
    assert args.alpha_mlm >= 0.0
    assert args.alpha_clm >= 0.0
    assert args.alpha_mse >= 0.0
    assert args.alpha_cos >= 0.0
    assert args.alpha_ce + args.alpha_mlm + args.alpha_clm + args.alpha_mse + args.alpha_cos > 0.0


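# Freezing helpers: the student can keep some of its embedding matrices fixed during
# distillation by simply turning off their gradients.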
def freeze_pos_embeddings(student, args):
    if args.student_type == "roberta":
        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
    elif args.student_type == "gpt2":
        student.transformer.wpe.weight.requires_grad = False


def freeze_token_type_embeddings(student, args):
    if args.student_type == "roberta":
        student.roberta.embeddings.token_type_embeddings.weight.requires_grad = False


def main():
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--force", action="store_true", help="Overwrite dump_path if it already exists.")

    parser.add_argument(
        "--dump_path", type=str, required=True, help="The output directory (log, checkpoints, parameters, etc.)"
    )
    parser.add_argument(
        "--data_file",
        type=str,
        required=True,
        help="The binarized data file (tokenized + tokens_to_ids), grouped by sequence.",
    )

    parser.add_argument(
        "--student_type",
        type=str,
        choices=["distilbert", "roberta", "gpt2"],
        required=True,
        help="The student type (DistilBERT, RoBERTa, GPT2).",
    )
    parser.add_argument("--student_config", type=str, required=True, help="Path to the student configuration.")
    parser.add_argument(
        "--student_pretrained_weights", default=None, type=str, help="Load student initialization checkpoint."
    )

    parser.add_argument(
        "--teacher_type", choices=["bert", "roberta", "gpt2"], required=True, help="Teacher type (BERT, RoBERTa, GPT2)."
    )
    parser.add_argument("--teacher_name", type=str, required=True, help="The teacher model.")

    parser.add_argument("--temperature", default=2.0, type=float, help="Temperature for the softmax used in distillation.")
    parser.add_argument(
        "--alpha_ce", default=0.5, type=float, help="Linear weight for the distillation loss. Must be >=0."
    )
    parser.add_argument(
        "--alpha_mlm",
        default=0.0,
        type=float,
        help="Linear weight for the MLM loss. Must be >=0. Should be used in conjunction with the `mlm` flag.",
    )
    parser.add_argument("--alpha_clm", default=0.5, type=float, help="Linear weight for the CLM loss. Must be >=0.")
    parser.add_argument("--alpha_mse", default=0.0, type=float, help="Linear weight of the MSE loss. Must be >=0.")
    parser.add_argument(
        "--alpha_cos", default=0.0, type=float, help="Linear weight of the cosine embedding loss. Must be >=0."
    )

    parser.add_argument(
        "--mlm", action="store_true", help="The LM step: MLM or CLM. If `mlm` is set, MLM is used instead of CLM."
    )
    parser.add_argument(
        "--mlm_mask_prop",
        default=0.15,
        type=float,
        help="Proportion of tokens for which we need to make a prediction.",
    )
    parser.add_argument("--word_mask", default=0.8, type=float, help="Proportion of selected tokens to mask out.")
    parser.add_argument("--word_keep", default=0.1, type=float, help="Proportion of selected tokens to keep unchanged.")
    parser.add_argument("--word_rand", default=0.1, type=float, help="Proportion of selected tokens to randomly replace.")
    parser.add_argument(
        "--mlm_smoothing",
        default=0.7,
        type=float,
        help="Smoothing parameter to emphasize rarer tokens (see XLM, similar to word2vec).",
    )
    parser.add_argument("--token_counts", type=str, help="The token counts in the data_file for MLM.")

    parser.add_argument(
        "--restrict_ce_to_mask",
        action="store_true",
        help="If set, compute the distillation loss only on the [MLM] masked tokens' prediction distribution.",
    )
    parser.add_argument(
        "--freeze_pos_embs",
        action="store_true",
        help="Freeze positional embeddings during distillation. For student_type in ['roberta', 'gpt2'] only.",
    )
    parser.add_argument(
        "--freeze_token_type_embds",
        action="store_true",
        help="Freeze token type embeddings during distillation if existent. For student_type in ['roberta'] only.",
    )

    parser.add_argument("--n_epoch", type=int, default=3, help="Number of passes over the whole dataset.")
    parser.add_argument("--batch_size", type=int, default=5, help="Batch size (for each process).")
    parser.add_argument(
        "--group_by_size",
        action="store_false",
        help="Group sequences of similar length into the same batch (enabled by default; pass this flag to disable).",
    )

    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=50,
        help="Gradient accumulation for larger training batches.",
    )
    parser.add_argument("--warmup_prop", default=0.05, type=float, help="Linear warmup proportion.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay to apply, if any.")
    parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="Epsilon for the Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=5.0, type=float, help="Max gradient norm.")
    parser.add_argument("--initializer_range", default=0.02, type=float, help="Random initialization range.")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit.",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help=(
            "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
            "See details at https://nvidia.github.io/apex/amp.html"
        ),
    )
    parser.add_argument("--n_gpu", type=int, default=1, help="Number of GPUs in the node.")
    parser.add_argument("--local_rank", type=int, default=-1, help="Distributed training - Local rank.")
    parser.add_argument("--seed", type=int, default=56, help="Random seed.")

    parser.add_argument("--log_interval", type=int, default=500, help="Tensorboard logging interval.")
    parser.add_argument("--checkpoint_interval", type=int, default=4000, help="Checkpoint interval.")
    args = parser.parse_args()
    sanity_checks(args)

    # ARGS #
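    # Initialize the distributed/GPU setup, seed the RNGs, and (on the master process only)
    # prepare the serialization directory, wiping it first if `--force` was given.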
    init_gpu_params(args)
    set_seed(args)
    if args.is_master:
        if os.path.exists(args.dump_path):
            if not args.force:
                raise ValueError(
                    f"Serialization dir {args.dump_path} already exists, but you have not specified whether to"
                    " overwrite it. Use `--force` if you want to overwrite it."
                )
            else:
                shutil.rmtree(args.dump_path)

        if not os.path.exists(args.dump_path):
            os.makedirs(args.dump_path)
        logger.info(f"Experiment will be dumped and logged in {args.dump_path}")

        # SAVE PARAMS #
        logger.info(f"Param: {args}")
        with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
            json.dump(vars(args), f, indent=4)
        git_log(args.dump_path)

    student_config_class, student_model_class, _ = MODEL_CLASSES[args.student_type]
    teacher_config_class, teacher_model_class, teacher_tokenizer_class = MODEL_CLASSES[args.teacher_type]

    # TOKENIZER #
    tokenizer = teacher_tokenizer_class.from_pretrained(args.teacher_name)
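    # Map every special token name (e.g. `cls_token`) to its vocabulary id; these ids are excluded from MLM masking below.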
    special_tok_ids = {}
    for tok_name, tok_symbol in tokenizer.special_tokens_map.items():
        idx = tokenizer.all_special_tokens.index(tok_symbol)
        special_tok_ids[tok_name] = tokenizer.all_special_ids[idx]
    logger.info(f"Special tokens {special_tok_ids}")
    args.special_tok_ids = special_tok_ids
    args.max_model_input_size = tokenizer.max_model_input_sizes[args.teacher_name]

    # DATA LOADER #
    logger.info(f"Loading data from {args.data_file}")
    with open(args.data_file, "rb") as fp:
        data = pickle.load(fp)

    if args.mlm:
        logger.info(f"Loading token counts from {args.token_counts} (already pre-computed)")
        with open(args.token_counts, "rb") as fp:
            counts = pickle.load(fp)

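        # Smoothed inverse-frequency masking probabilities: rarer tokens are masked relatively
        # more often (exponent -mlm_smoothing), and special tokens are never masked.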
        token_probs = np.maximum(counts, 1) ** -args.mlm_smoothing
        for idx in special_tok_ids.values():
            token_probs[idx] = 0.0  # do not predict special tokens
        token_probs = torch.from_numpy(token_probs)
    else:
        token_probs = None

    train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
    logger.info("Data loader created.")

    # STUDENT #
    logger.info(f"Loading student config from {args.student_config}")
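    # The student is built from a local JSON config; output_hidden_states is forced to True because
    # the distillation losses can also use the intermediate hidden states (e.g. the cosine embedding loss).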
    stu_architecture_config = student_config_class.from_pretrained(args.student_config)
    stu_architecture_config.output_hidden_states = True

    if args.student_pretrained_weights is not None:
        logger.info(f"Loading pretrained weights from {args.student_pretrained_weights}")
        student = student_model_class.from_pretrained(args.student_pretrained_weights, config=stu_architecture_config)
    else:
        student = student_model_class(stu_architecture_config)

    if args.n_gpu > 0:
        student.to(f"cuda:{args.local_rank}")
    logger.info("Student loaded.")

    # TEACHER #
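    # The teacher is only used for inference inside the Distiller; it also exposes its hidden states.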
    teacher = teacher_model_class.from_pretrained(args.teacher_name, output_hidden_states=True)
    if args.n_gpu > 0:
        teacher.to(f"cuda:{args.local_rank}")
    logger.info(f"Teacher loaded from {args.teacher_name}.")

    # FREEZING #
    if args.freeze_pos_embs:
        freeze_pos_embeddings(student, args)
    if args.freeze_token_type_embds:
        freeze_token_type_embeddings(student, args)

    # SANITY CHECKS #
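    # Student and teacher must share the same vocabulary, hidden size and maximum sequence length
    # so that their logits and hidden states can be compared position by position.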
    assert student.config.vocab_size == teacher.config.vocab_size
    assert student.config.hidden_size == teacher.config.hidden_size
    assert student.config.max_position_embeddings == teacher.config.max_position_embeddings
    if args.mlm:
        assert token_probs.size(0) == stu_architecture_config.vocab_size

    # DISTILLER #
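    # Free any cached GPU memory before the Distiller sets up its training state.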
    torch.cuda.empty_cache()
    distiller = Distiller(
        params=args, dataset=train_lm_seq_dataset, token_probs=token_probs, student=student, teacher=teacher
    )
    distiller.train()
    logger.info("Let's go get some drinks.")


if __name__ == "__main__":
    main()