add head masking and pruning to openai GPT

thomwolf 2019-06-17 14:19:40 +02:00
parent b860e47cf5
commit f12007e421
2 changed files with 149 additions and 21 deletions

View File

@@ -36,6 +36,7 @@ from torch.nn.parameter import Parameter
from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
from .modeling import BertLayerNorm as LayerNorm
from .modeling_gpt2 import prune_conv1d_layer
logger = logging.getLogger(__name__)
@@ -256,7 +257,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -265,13 +266,31 @@ class Attention(nn.Module):
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.output_attentions = output_attentions
self.keep_multihead_output = keep_multihead_output
self.multihead_output = None
self.c_attn = Conv1D(n_state * 3, 1, nx)
self.c_proj = Conv1D(n_state, 1, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def _attn(self, q, k, v):
def prune_heads(self, heads):
mask = torch.ones(self.n_head, self.split_size // self.n_head)
for head in heads:
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
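To make the index bookkeeping above concrete, here is a minimal standalone sketch (plain PyTorch, not the library code) for a toy configuration of 4 heads of size 3; the kept columns are computed exactly as in prune_heads and then handed to prune_conv1d_layer, which rebuilds c_attn and c_proj around them:

import torch

n_head, head_size = 4, 3                   # toy values; the model uses config.n_head and n_embd // n_head
split_size = n_head * head_size
heads_to_prune = [1, 3]

mask = torch.ones(n_head, head_size)
for head in heads_to_prune:
    mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)    # flat boolean mask over the hidden dimension
index = torch.arange(len(mask))[mask].long()
# c_attn packs the q, k and v projections side by side, so the kept pattern repeats three times
index_attn = torch.cat([index, index + split_size, index + 2 * split_size])

print(index.tolist())                      # [0, 1, 2, 6, 7, 8]
print(index_attn.tolist())                 # the same pattern shifted by 12 and 24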
def _attn(self, q, k, v, head_mask=None):
w = torch.matmul(q, k)
if self.scale:
w = w / math.sqrt(v.size(-1))
@@ -282,6 +301,11 @@ class Attention(nn.Module):
w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
# Mask heads if we want to
if head_mask is not None:
w = w * head_mask
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v)
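A small self-contained sketch (illustrative shapes only, not the library code) of what the multiplication above does to the attention probabilities: a zero entry in the broadcast mask silences an entire head, so that head contributes nothing to the context vectors.

import torch
import torch.nn as nn

bsz, n_head, seq_len = 2, 4, 5
w = nn.Softmax(dim=-1)(torch.randn(bsz, n_head, seq_len, seq_len))   # attention probabilities

head_mask = torch.ones(1, n_head, 1, 1)    # broadcast shape prepared in OpenAIGPTModel.forward
head_mask[:, 1] = 0.0                      # zero multiplier for head 1
w = w * head_mask

assert w[:, 1].abs().sum().item() == 0.0   # head 1 is silenced everywhere
assert w[:, 0].abs().sum().item() > 0.0    # the other heads are untouched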
@@ -299,13 +323,18 @@ class Attention(nn.Module):
else:
return x.permute(0, 2, 1, 3)
def forward(self, x):
def forward(self, x, head_mask=None):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
a = self._attn(query, key, value)
a = self._attn(query, key, value, head_mask)
if self.output_attentions:
    attentions, a = a
if self.keep_multihead_output:
    # store the per-head context vectors (after unpacking the attentions, so this is
    # always a tensor) so their gradients can be inspected after backward
    self.multihead_output = a
    self.multihead_output.retain_grad()
a = self.merge_heads(a)
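Keeping multihead_output with retain_grad() is what enables gradient-based head analysis after a backward pass. A hedged sketch of a Taylor-style importance score, similar in spirit to recent head-importance analyses (the helper name and exact reduction below are illustrative, not part of the library):

import torch

def head_importance(multihead_outputs):
    # multihead_outputs: list over layers of [batch, n_head, seq_len, head_size] tensors,
    # e.g. OpenAIGPTModel.get_multihead_outputs() called after loss.backward()
    scores = []
    for layer_output in multihead_outputs:
        # |output * d(loss)/d(output)|, summed over batch, positions and the head dimension
        scores.append((layer_output * layer_output.grad).abs().sum(dim=(0, 2, 3)))
    return torch.stack(scores)             # [n_layer, n_head]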
@@ -332,17 +361,17 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False, output_attentions=False):
def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
super(Block, self).__init__()
nx = config.n_embd
self.output_attentions = output_attentions
self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
def forward(self, x):
a = self.attn(x)
def forward(self, x, head_mask=None):
a = self.attn(x, head_mask=head_mask)
if self.output_attentions:
attentions, a = a
n = self.ln_1(x + a)
@@ -614,13 +643,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(OpenAIGPTModel, self).__init__(config)
self.output_attentions = output_attentions
self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.apply(self.init_weights)
@@ -639,7 +669,20 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
# Copy word embeddings from the previous weights
self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
def forward(self, input_ids, position_ids=None, token_type_ids=None):
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: a list with one element per layer containing the multi-head attention outputs (gradients retained)
"""
return [h.attn.multihead_output for h in self.h]
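A hedged usage sketch of the two methods above together with the head_mask argument added to forward below. It assumes from_pretrained forwards the extra keep_multihead_output keyword to the constructor and that the 'openai-gpt' weights can be downloaded; the token ids are purely illustrative.

import torch
from pytorch_pretrained_bert import OpenAIGPTModel

model = OpenAIGPTModel.from_pretrained('openai-gpt', keep_multihead_output=True)
model.eval()

input_ids = torch.tensor([[464, 3290, 318, 13779]])   # illustrative token ids

# On-the-fly masking: 1.0 marks a head to silence for this forward pass only
head_mask = torch.zeros(model.config.n_head)
head_mask[3] = 1.0
hidden_states = model(input_ids, head_mask=head_mask)

# Per-layer [batch, n_head, seq_len, head_size] context vectors, kept for inspection
multihead_outputs = model.get_multihead_outputs()

# Permanent pruning: drop heads 0 and 1 of the first layer and head 2 of the last one
model.prune_heads({0: [0, 1], -1: [2]})
hidden_states = model(input_ids)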
def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
if position_ids is None:
# This was used when we had a single embedding matrix for position and token embeddings
# start = self.config.vocab_size + self.config.n_special
@@ -648,6 +691,17 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Prepare head mask if needed
# 1.0 in head_mask indicates that we mask the head
# attention_probs has shape bsz x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # we can specify a head_mask per instance in the batch
head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
head_mask = (1.0 - head_mask)
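The reshaping above accepts either a single mask shared by the whole batch (1D) or one mask per example (2D); a standalone sketch of the resulting broadcast shapes with toy sizes:

import torch

n_head, bsz = 12, 3

shared = torch.ones(n_head)
print(shared.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).shape)    # torch.Size([1, 12, 1, 1])

per_example = torch.ones(bsz, n_head)
print(per_example.unsqueeze(-1).unsqueeze(-1).shape)            # torch.Size([3, 12, 1, 1])

Either shape broadcasts against attention weights of shape [bsz, n_head, seq_len, seq_len], and the final (1.0 - head_mask) flips the convention so that a user-supplied 1.0 ("mask this head") becomes a 0.0 multiplier inside Attention._attn.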
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
@@ -664,11 +718,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
all_attentions = []
for block in self.h:
outputs = block(hidden_states, head_mask)
if self.output_attentions:
attentions, hidden_states = block(hidden_states)
attentions, hidden_states = outputs
all_attentions.append(attentions)
else:
hidden_states = block(hidden_states)
hidden_states = outputs
output_shape = input_shape + (hidden_states.size(-1),)
if self.output_attentions:
return all_attentions, hidden_states.view(*output_shape)
@@ -731,9 +786,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(OpenAIGPTLMHeadModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.apply(self.init_weights)
@@ -745,8 +801,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states)
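A hedged sketch of running the LM head with a head mask (it assumes the 'openai-gpt' weights are available and uses illustrative token ids); since head_mask is simply forwarded to the transformer, masked-head comparisons only need this extra keyword:

import torch
from pytorch_pretrained_bert import OpenAIGPTLMHeadModel

model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()

input_ids = torch.tensor([[464, 3290, 318, 13779]])   # illustrative token ids
head_mask = torch.zeros(model.config.n_head)
head_mask[0] = 1.0                                    # silence head 0 in every layer

# with lm_labels set, the model returns the language modeling loss directly
loss = model(input_ids, lm_labels=input_ids, head_mask=head_mask)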
@@ -825,9 +881,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False):
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
self.apply(self.init_weights)
@@ -840,8 +897,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, head_mask=None):
hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states = hidden_states
lm_logits = self.lm_head(hidden_states)

View File

@@ -182,6 +182,73 @@ class OpenAIGPTModelTest(unittest.TestCase):
[list(l.size()) for l in result["loss"]],
[[], []])
def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.ones(self.n_head).to(input_ids.device)
head_mask[0] = 0.0
head_mask[-1] = 0.0 # Mask all but the first and last heads
if isinstance(model, OpenAIGPTDoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
output = model(input_ids, head_mask=head_mask)
output = sum(t.sum() for t in output[:-1])
output = output.sum()
output.backward()
multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertEqual(
len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer
heads_to_prune = {0: list(range(1, self.n_head)),
-1: [0]}
transformer.prune_heads(heads_to_prune)
if isinstance(model, OpenAIGPTDoubleHeadsModel):
output = model(input_ids, mc_token_ids)
else:
output = model(input_ids)
output = sum(t.sum() for t in output[:-1])
output = output.sum()
output.backward()
multihead_outputs = transformer.get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, 1,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head-1,
self.seq_length, self.n_embd // self.n_head])
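Beyond the shape checks in the test above, a quick hedged sanity check (not part of the test file; it assumes the pretrained 'openai-gpt' weights are available) is that pruning really shrinks the attention projections:

from pytorch_pretrained_bert import OpenAIGPTModel

model = OpenAIGPTModel.from_pretrained('openai-gpt')
n_before = sum(p.numel() for p in model.parameters())

model.prune_heads({0: [0]})                # drop a single head in the first layer
n_after = sum(p.numel() for p in model.parameters())

# c_attn loses 3 * (n_embd // n_head) output columns and c_proj the matching input rows,
# so the total parameter count must go down
assert n_after < n_before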
def test_default(self):
self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))
@@ -220,6 +287,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
tester.check_openai_double_heads_output(output_result)
tester.check_openai_double_heads_loss_output(output_result)
tester.create_and_check_openai_for_headmasking(*config_and_inputs)
tester.create_and_check_openai_for_head_pruning(*config_and_inputs)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""