454 lines
18 KiB
Python
454 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright 2018 CMU and The HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Bertology: this script shows how you can explore the internals of the models in the library to:
|
|
- compute the entropy of the head attentions
|
|
- compute the importance of each head
|
|
- prune (remove) the low-importance heads.
|
|
Some parts of this script are adapted from the code of Michel et al. (http://arxiv.org/abs/1905.10650)
|
|
which is available at https://github.com/pmichel31415/are-16-heads-really-better-than-1
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
from datetime import datetime
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader, SequentialSampler, Subset
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from tqdm import tqdm
|
|
|
|
import transformers
|
|
from transformers import (
|
|
AutoConfig,
|
|
AutoModelForSequenceClassification,
|
|
AutoTokenizer,
|
|
GlueDataset,
|
|
default_data_collator,
|
|
glue_compute_metrics,
|
|
glue_output_modes,
|
|
glue_processors,
|
|
set_seed,
|
|
)
|
|
from transformers.trainer_utils import is_main_process
|
|
|
|
|
|
# Module-level logger; its records go through the root logging configuration
# that main() sets up with logging.basicConfig.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def entropy(p):
    """Return the entropy (natural log) of ``p`` reduced over its last dimension.

    Entries where ``p == 0`` contribute nothing, following the convention
    ``0 * log(0) = 0``; the raw product would otherwise be NaN there.
    """
    # p * log(p) is 0 * -inf = NaN wherever p == 0; zero those terms out.
    terms = (p * torch.log(p)).masked_fill(p == 0, 0)
    return -terms.sum(dim=-1)
|
|
|
|
|
|
def print_2d_tensor(tensor):
    """Log a 2D tensor as a table with one line per row ("layer") and one column per head.

    The header lists the column indices ``1..n_cols``; each following line is labelled
    ``layer <row + 1>``. ``torch.long`` tensors are printed as integers, every other
    dtype with five decimals.

    Args:
        tensor: a 2D tensor (rows = layers, columns = heads in this script's usage).
    """
    n_cols = tensor.shape[1]
    # The header enumerates columns (heads). Using len(tensor) here would count rows
    # (layers) and mislabel any tensor whose layer and head counts differ.
    logger.info("lv, h >\t" + "\t".join(f"{x + 1}" for x in range(n_cols)))
    fmt = "{:d}" if tensor.dtype == torch.long else "{:.5f}"
    # Convert once to Python numbers (host copy, no autograd tracking) before formatting.
    rows = tensor.detach().cpu().tolist()
    for row, values in enumerate(rows):
        logger.info(f"layer {row + 1}:\t" + "\t".join(fmt.format(x) for x in values))
|
|
|
|
|
|
def compute_heads_importance(
    args, model, eval_dataloader, compute_entropy=True, compute_importance=True, head_mask=None, actually_pruned=False
):
    """This method shows how to compute:
    - head attention entropy
    - head importance scores according to http://arxiv.org/abs/1905.10650

    Args:
        args: parsed script arguments; reads ``device``, ``local_rank``, ``output_dir``,
            ``dont_normalize_importance_by_layer`` and ``dont_normalize_global_importance``.
        model: classification model whose outputs start with (loss, logits) and whose last
            output is expected to be the per-layer attentions (main() enables output_attentions).
        eval_dataloader: iterable of input dicts; each must contain ``attention_mask`` and ``labels``.
        compute_entropy: accumulate the token-averaged attention entropy of every head.
        compute_importance: accumulate |d loss / d head_mask| for every head (runs a backward pass).
        head_mask: optional (n_layers, n_heads) mask, all ones by default. The caller's tensor is
            never modified; a private copy is used for gradients.
        actually_pruned: heads were physically removed from the model, so no mask is passed to it.

    Returns:
        ``(attn_entropy, head_importance, preds, labels)``. A score tensor is all zeros when its
        ``compute_*`` flag is off. ``preds``/``labels`` are the concatenated logits and labels as
        numpy arrays (``None`` if the dataloader is empty).

    Raises:
        ValueError: if importance is requested for an actually pruned model (there is no mask to
            differentiate against).
    """
    # Prepare our tensors
    n_layers, n_heads = model.config.num_hidden_layers, model.config.num_attention_heads
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)

    if actually_pruned:
        # The attention modules no longer have n_heads heads: a full-size mask would cause a
        # shape mismatch, and without a mask there is nothing to take gradients against.
        if compute_importance:
            raise ValueError("Head importance cannot be computed on an actually pruned model.")
        head_mask = None
    else:
        if head_mask is None:
            head_mask = torch.ones(n_layers, n_heads)
        # Use a fresh leaf copy: .grad is only populated on leaves, and turning the caller's
        # tensor into a grad-requiring leaf breaks later in-place edits of it (see mask_heads).
        head_mask = head_mask.detach().clone().to(args.device).requires_grad_(compute_importance)

    all_preds, all_labels = [], []
    tot_tokens = 0.0

    for inputs in tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]):
        inputs = {k: v.to(args.device) for k, v in inputs.items()}

        # Build the autograd graph only when importance scores need the backward pass.
        with torch.set_grad_enabled(compute_importance):
            outputs = model(**inputs, head_mask=head_mask)
        # Loss and logits are the first outputs, the attentions the last
        loss, logits, all_attentions = outputs[0], outputs[1], outputs[-1]

        if compute_importance:
            loss.backward()  # Backpropagate to populate the gradients in the head mask
            head_importance += head_mask.grad.abs().detach()
            # .grad accumulates across backward() calls; reset it so every batch adds only
            # its own |gradient| instead of the running sum of all previous batches.
            head_mask.grad = None

        if compute_entropy:
            for layer, attn in enumerate(all_attentions):
                masked_entropy = entropy(attn.detach()) * inputs["attention_mask"].float().unsqueeze(1)
                attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

        # Also store our logits/labels if we want to compute metrics afterwards
        all_preds.append(logits.detach().cpu().numpy())
        all_labels.append(inputs["labels"].detach().cpu().numpy())

        tot_tokens += inputs["attention_mask"].float().detach().sum().data

    # Normalize by the number of attended tokens
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens

    if compute_importance:
        # Layerwise importance normalization
        if not args.dont_normalize_importance_by_layer:
            exponent = 2
            norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
            head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

        if not args.dont_normalize_global_importance:
            # Rescale to [0, 1]; the epsilon avoids 0/0 (NaN) when all scores are equal.
            importance_range = head_importance.max() - head_importance.min()
            head_importance = (head_importance - head_importance.min()) / (importance_range + 1e-20)

    # Print/save matrices. Only write what was actually computed, so that later calls with a
    # flag turned off do not overwrite earlier results with zeros/NaN.
    if compute_entropy:
        np.save(os.path.join(args.output_dir, "attn_entropy.npy"), attn_entropy.detach().cpu().numpy())
        logger.info("Attention entropies")
        print_2d_tensor(attn_entropy)

    if compute_importance:
        np.save(os.path.join(args.output_dir, "head_importance.npy"), head_importance.detach().cpu().numpy())
        logger.info("Head importance scores")
        print_2d_tensor(head_importance)

        logger.info("Head ranked by importance scores")
        head_ranks = torch.zeros(head_importance.numel(), dtype=torch.long, device=args.device)
        head_ranks[head_importance.view(-1).sort(descending=True)[1]] = torch.arange(
            head_importance.numel(), device=args.device
        )
        head_ranks = head_ranks.view_as(head_importance)
        print_2d_tensor(head_ranks)

    # Concatenate once at the end instead of re-copying with np.append on every batch.
    preds = np.concatenate(all_preds, axis=0) if all_preds else None
    labels = np.concatenate(all_labels, axis=0) if all_labels else None

    return attn_entropy, head_importance, preds, labels
|
|
|
|
|
|
def mask_heads(args, model, eval_dataloader):
    """This method shows how to mask head (set some heads to zero), to test the effect on the network,
    based on the head importance scores, as described in Michel et al. (http://arxiv.org/abs/1905.10650)

    Heads are masked greedily: each step masks ``args.masking_amount`` of all heads (at least one),
    choosing the least important heads that are still unmasked, then re-evaluates the model and
    recomputes importance. Masking stops when the metric drops below
    ``args.masking_threshold * original_score``, or when no more than one step's worth of
    unmasked heads remains.

    Returns:
        The last head mask (1.0 = kept, 0.0 = masked) whose score stayed at or above the threshold.
        It is also saved to ``head_mask.npy`` in ``args.output_dir``.
    """
    _, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy=False)
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    original_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    logger.info("Pruning: original score: %f, threshold: %f", original_score, original_score * args.masking_threshold)

    new_head_mask = torch.ones_like(head_importance)
    num_to_mask = max(1, int(new_head_mask.numel() * args.masking_amount))

    current_score = original_score
    while current_score >= original_score * args.masking_threshold:
        # Save the current head mask as a plain (no-grad) copy; the evaluation below may have
        # turned new_head_mask into a tensor that requires grad.
        head_mask = new_head_mask.detach().clone()
        # heads from least important to most - keep only not-masked heads
        head_importance[head_mask == 0.0] = float("Inf")
        current_heads_to_mask = head_importance.view(-1).sort()[1]
        current_heads_to_mask = current_heads_to_mask[head_mask.view(-1)[current_heads_to_mask] != 0.0]

        # Stop when another step would mask every remaining head. Without filtering out already
        # masked heads above, this length never shrank and the loop could run forever.
        if len(current_heads_to_mask) <= num_to_mask:
            break

        # mask heads
        current_heads_to_mask = current_heads_to_mask[:num_to_mask]
        logger.info("Heads to mask: %s", str(current_heads_to_mask.tolist()))
        new_head_mask = head_mask.clone()
        new_head_mask.view(-1)[current_heads_to_mask] = 0.0
        print_2d_tensor(new_head_mask)

        # Compute metric and head importance again
        _, head_importance, preds, labels = compute_heads_importance(
            args, model, eval_dataloader, compute_entropy=False, head_mask=new_head_mask
        )
        preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
        current_score = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
        remaining_heads = int(new_head_mask.detach().sum().item())
        logger.info(
            "Masking: current score: %f, remaining heads %d (%.1f percents)",
            current_score,
            remaining_heads,
            remaining_heads / new_head_mask.numel() * 100,
        )

    logger.info("Final head mask")
    print_2d_tensor(head_mask)
    np.save(os.path.join(args.output_dir, "head_mask.npy"), head_mask.detach().cpu().numpy())

    return head_mask
|
|
|
|
|
|
def prune_heads(args, model, eval_dataloader, head_mask):
    """This method shows how to prune head (remove heads weights) based on
    the head importance scores as described in Michel et al. (http://arxiv.org/abs/1905.10650)

    The model is evaluated once with ``head_mask`` applied. The masked heads are then physically
    removed with ``model.prune_heads`` and the model is evaluated again. Parameter counts, both
    scores and the ratio of the two evaluation times are logged.

    Args:
        head_mask: (n_layers, n_heads) tensor; entries that truncate to 0 mark heads to remove.
    """
    # Try pruning and test time speedup
    # Pruning is like masking but we actually remove the masked weights
    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args, model, eval_dataloader, compute_entropy=False, compute_importance=False, head_mask=head_mask
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_masking = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    original_time = datetime.now() - before_time

    original_num_params = sum(p.numel() for p in model.parameters())
    # flatten() rather than squeeze(): with exactly one head to prune in a layer, squeeze() gives
    # a 0-d tensor whose tolist() is a bare int, which breaks len() below and model.prune_heads.
    heads_to_prune = {
        layer: (1 - head_mask[layer].long()).nonzero().flatten().tolist() for layer in range(len(head_mask))
    }

    # Sanity check: every masked head appears exactly once in heads_to_prune.
    assert sum(len(h) for h in heads_to_prune.values()) == (1 - head_mask.long()).sum().item()
    model.prune_heads(heads_to_prune)
    pruned_num_params = sum(p.numel() for p in model.parameters())

    before_time = datetime.now()
    _, _, preds, labels = compute_heads_importance(
        args,
        model,
        eval_dataloader,
        compute_entropy=False,
        compute_importance=False,
        head_mask=None,
        actually_pruned=True,
    )
    preds = np.argmax(preds, axis=1) if args.output_mode == "classification" else np.squeeze(preds)
    score_pruning = glue_compute_metrics(args.task_name, preds, labels)[args.metric_name]
    new_time = datetime.now() - before_time

    logger.info(
        "Pruning: original num of params: %.2e, after pruning %.2e (%.1f percents)",
        original_num_params,
        pruned_num_params,
        pruned_num_params / original_num_params * 100,
    )
    logger.info("Pruning: score with masking: %f score with pruning: %f", score_masking, score_pruning)
    # timedelta / timedelta is a float; values above 100 mean the pruned model evaluated faster.
    # The label previously said "new / original", the inverse of what is computed.
    logger.info("Pruning: speed ratio (original timing / new timing): %f percents", original_time / new_time * 100)
|
|
|
|
|
|
def _parse_args():
    """Define this script's command-line interface and return the parsed arguments."""
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(glue_processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name_or_path",
    )
    parser.add_argument(
        "--cache_dir",
        default=None,
        type=str,
        help="Where do you want to store the pre-trained models downloaded from huggingface.co",
    )
    parser.add_argument(
        "--data_subset", type=int, default=-1, help="If > 0: limit the data to a subset of data_subset instances."
    )
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Whether to overwrite data in output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )

    parser.add_argument(
        "--dont_normalize_importance_by_layer", action="store_true", help="Don't normalize importance score by layers"
    )
    parser.add_argument(
        "--dont_normalize_global_importance",
        action="store_true",
        help="Don't normalize all importance scores between 0 and 1",
    )

    parser.add_argument(
        "--try_masking", action="store_true", help="Whether to try to mask head until a threshold of accuracy."
    )
    parser.add_argument(
        "--masking_threshold",
        default=0.9,
        type=float,
        help="masking threshold in term of metrics (stop masking when metric < threshold * original metric value).",
    )
    parser.add_argument(
        "--masking_amount", default=0.1, type=float, help="Amount to heads to masking at each masking step."
    )
    parser.add_argument("--metric_name", default="acc", type=str, help="Metric to use for head masking.")

    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=(
            "The maximum total input sequence length after WordPiece tokenization. \n"
            "Sequences longer than this will be truncated, sequences shorter padded."
        ),
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    return parser.parse_args()


def main():
    """Script entry point.

    Parses arguments, sets up devices/logging, loads the model and the GLUE dev set, computes
    head entropy and importance, and, with ``--try_masking``, masks then prunes heads.
    """
    args = _parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup devices and distributed training
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        args.n_gpu = 1
        torch.distributed.init_process_group(backend="nccl")  # Initializes the distributed backend

    # Setup logging
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed: {}".format(args.device, args.n_gpu, bool(args.local_rank != -1)))
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()

    # Set seeds
    set_seed(args.seed)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in glue_processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = glue_processors[args.task_name]()
    args.output_mode = glue_output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        output_attentions=True,
        cache_dir=args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        cache_dir=args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Distributed and parallel training
    model.to(args.device)
    if args.local_rank != -1:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )
    elif args.n_gpu > 1:
        model = nn.DataParallel(model)

    # Print/save training arguments
    os.makedirs(args.output_dir, exist_ok=True)
    torch.save(args, os.path.join(args.output_dir, "run_args.bin"))
    logger.info("Training/evaluation parameters %s", args)

    # Prepare dataset for the GLUE task
    eval_dataset = GlueDataset(args, tokenizer=tokenizer, mode="dev")
    if args.data_subset > 0:
        eval_dataset = Subset(eval_dataset, list(range(min(args.data_subset, len(eval_dataset)))))
    eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.batch_size, collate_fn=default_data_collator
    )

    # Compute head entropy and importance score
    compute_heads_importance(args, model, eval_dataloader)

    # Try head masking (set heads to zero until the score goes under a threshold)
    # and head pruning (remove masked heads and see the effect on the network)
    if args.try_masking:
        if 0.0 < args.masking_threshold < 1.0:
            head_mask = mask_heads(args, model, eval_dataloader)
            prune_heads(args, model, eval_dataloader, head_mask)
        else:
            # Previously this case was skipped silently, hiding a misconfigured run.
            logger.warning(
                "Skipping head masking/pruning: --masking_threshold must be in (0, 1), got %f",
                args.masking_threshold,
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the script only when executed directly, not when imported as a module.
    main()
|