add head masking and pruning to openai GPT
parent b860e47cf5
commit f12007e421

@@ -36,6 +36,7 @@ from torch.nn.parameter import Parameter
 
 from .file_utils import cached_path, CONFIG_NAME, WEIGHTS_NAME
 from .modeling import BertLayerNorm as LayerNorm
+from .modeling_gpt2 import prune_conv1d_layer
 
 logger = logging.getLogger(__name__)
 
@@ -256,7 +257,7 @@ class Conv1D(nn.Module):
 
 
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Attention, self).__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -265,13 +266,31 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+
         self.output_attentions = output_attentions
+        self.keep_multihead_output = keep_multihead_output
+        self.multihead_output = None
+
         self.c_attn = Conv1D(n_state * 3, 1, nx)
         self.c_proj = Conv1D(n_state, 1, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
-    def _attn(self, q, k, v):
+    def prune_heads(self, heads):
+        mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        for head in heads:
+            mask[head] = 0
+        mask = mask.view(-1).contiguous().eq(1)
+        index = torch.arange(len(mask))[mask].long()
+        index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
+        # Prune conv1d layers
+        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
+        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
+        # Update hyper params
+        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
+        self.n_head = self.n_head - len(heads)
+
+    def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
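
The index arithmetic in prune_heads can be checked in isolation. The snippet below is not part of the commit; it uses made-up sizes (n_head=4, split_size=8) to show which columns survive when head 1 is pruned, before prune_conv1d_layer (imported from modeling_gpt2, not shown in this diff) slices c_attn and c_proj down to those indices.

import torch

n_head, split_size = 4, 8          # hypothetical sizes, so each head spans 2 columns
heads = [1]                        # prune head 1
mask = torch.ones(n_head, split_size // n_head)
for head in heads:
    mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
index_attn = torch.cat([index, index + split_size, index + (2 * split_size)])
print(index)       # tensor([0, 1, 4, 5, 6, 7]) -> columns kept in each projection
print(index_attn)  # the same pattern repeated for the query, key and value blocks of c_attn
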
@@ -282,6 +301,11 @@ class Attention(nn.Module):
 
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            w = w * head_mask
+
         if self.output_attentions:
             return w, torch.matmul(w, v)
         return torch.matmul(w, v)
@@ -299,13 +323,18 @@ class Attention(nn.Module):
         else:
             return x.permute(0, 2, 1, 3)
 
-    def forward(self, x):
+    def forward(self, x, head_mask=None):
         x = self.c_attn(x)
         query, key, value = x.split(self.split_size, dim=2)
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
-        a = self._attn(query, key, value)
+
+        a = self._attn(query, key, value, head_mask)
+        if self.keep_multihead_output:
+            self.multihead_output = a
+            self.multihead_output.retain_grad()
+
         if self.output_attentions:
             attentions, a = a
         a = self.merge_heads(a)
@@ -332,17 +361,17 @@ class MLP(nn.Module):
 
 
 class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False, output_attentions=False):
+    def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
         super(Block, self).__init__()
         nx = config.n_embd
         self.output_attentions = output_attentions
-        self.attn = Attention(nx, n_ctx, config, scale, output_attentions)
+        self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
         self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.mlp = MLP(4 * nx, config)
         self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
 
-    def forward(self, x):
-        a = self.attn(x)
+    def forward(self, x, head_mask=None):
+        a = self.attn(x, head_mask=head_mask)
         if self.output_attentions:
             attentions, a = a
         n = self.ln_1(x + a)
@@ -614,13 +643,14 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
     ```
     """
 
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
        super(OpenAIGPTModel, self).__init__(config)
        self.output_attentions = output_attentions
        self.tokens_embed = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
        self.positions_embed = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
-        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions)
+        block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
+                      keep_multihead_output=keep_multihead_output)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
 
        self.apply(self.init_weights)
@@ -639,7 +669,20 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         # Copy word embeddings from the previous weights
         self.tokens_embed.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+    def prune_heads(self, heads_to_prune):
+        """ Prunes heads of the model.
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
+        """
+        for layer, heads in heads_to_prune.items():
+            self.h[layer].attn.prune_heads(heads)
+
+    def get_multihead_outputs(self):
+        """ Gather all multi-head outputs.
+            Return: list (layers) of multihead module outputs with gradients
+        """
+        return [h.attn.multihead_output for h in self.h]
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, head_mask=None):
         if position_ids is None:
             # This was used when we had a single embedding matrice from position and token embeddings
             # start = self.config.vocab_size + self.config.n_special
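
Together, prune_heads and get_multihead_outputs are called on the model object. The snippet below is a usage sketch rather than code from the commit; the package import path, config defaults, and example token ids are assumptions.

import torch
from pytorch_pretrained_bert import OpenAIGPTModel, OpenAIGPTConfig  # import path assumed

config = OpenAIGPTConfig()                        # default config, randomly initialised weights
model = OpenAIGPTModel(config, keep_multihead_output=True)
model.eval()

# Drop heads 0 and 2 of layer 0 and head 1 of layer 1.
model.prune_heads({0: [0, 2], 1: [1]})

input_ids = torch.tensor([[31, 51, 99]])          # arbitrary token ids
hidden_states = model(input_ids)

# One tensor per layer, shaped [batch, remaining_heads, seq_len, head_dim];
# populated only because keep_multihead_output=True was passed at construction.
per_head_outputs = model.get_multihead_outputs()
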
@@ -648,6 +691,17 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
             position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
 
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we mask the head
+        # attention_probs has shape bsz x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each instance in batch
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to fload if need + fp16 compatibility
+            head_mask = (1.0 - head_mask)
+
         input_shape = input_ids.size()
         input_ids = input_ids.view(-1, input_ids.size(-1))
         position_ids = position_ids.view(-1, position_ids.size(-1))
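
Note the convention wired in here: a 1.0 entry in head_mask marks a head to silence, and the mask is inverted before being multiplied into the attention probabilities, so the caller can pass either one value per head (shared across the batch) or one row per example. A hypothetical call, with the import path and config defaults assumed, could look like:

import torch
from pytorch_pretrained_bert import OpenAIGPTModel, OpenAIGPTConfig  # import path assumed

model = OpenAIGPTModel(OpenAIGPTConfig())         # randomly initialised, for illustration only
model.eval()

input_ids = torch.tensor([[31, 51, 99]])
head_mask = torch.zeros(model.config.n_head)      # 0.0 = keep, 1.0 = mask under this convention
head_mask[[0, 5]] = 1.0                           # silence heads 0 and 5 in every layer

hidden_states = model(input_ids, head_mask=head_mask)
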
@@ -664,11 +718,12 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
 
         all_attentions = []
         for block in self.h:
+            outputs = block(hidden_states, head_mask)
             if self.output_attentions:
-                attentions, hidden_states = block(hidden_states)
+                attentions, hidden_states = outputs
                 all_attentions.append(attentions)
             else:
-                hidden_states = block(hidden_states)
+                hidden_states = outputs
         output_shape = input_shape + (hidden_states.size(-1),)
         if self.output_attentions:
             return all_attentions, hidden_states.view(*output_shape)
@@ -731,9 +786,10 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
     ```
     """
 
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTLMHeadModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
+                                          keep_multihead_output=keep_multihead_output)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.apply(self.init_weights)
 
@@ -745,8 +801,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
 
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, head_mask=None):
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
@@ -825,9 +881,10 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
     ```
     """
 
-    def __init__(self, config, output_attentions=False):
+    def __init__(self, config, output_attentions=False, keep_multihead_output=False):
         super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
-        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions)
+        self.transformer = OpenAIGPTModel(config, output_attentions=output_attentions,
+                                          keep_multihead_output=keep_multihead_output)
         self.lm_head = OpenAIGPTLMHead(self.transformer.tokens_embed.weight, config)
         self.multiple_choice_head = OpenAIGPTMultipleChoiceHead(config)
         self.apply(self.init_weights)
@@ -840,8 +897,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.transformer.set_num_special_tokens(num_special_tokens)
         self.lm_head.set_embeddings_weights(self.transformer.tokens_embed.weight, predict_special_tokens=predict_special_tokens)
 
-    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None, position_ids=None):
-        hidden_states = self.transformer(input_ids, position_ids, token_type_ids)
+    def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
+                position_ids=None, head_mask=None):
+        hidden_states = self.transformer(input_ids, position_ids, token_type_ids, head_mask)
         if self.transformer.output_attentions:
             all_attentions, hidden_states = hidden_states
         lm_logits = self.lm_head(hidden_states)
@@ -182,6 +182,73 @@ class OpenAIGPTModelTest(unittest.TestCase):
                 [list(l.size()) for l in result["loss"]],
                 [[], []])
 
+        def create_and_check_openai_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
+                                                    mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                head_mask = torch.ones(self.n_head).to(input_ids.device)
+                head_mask[0] = 0.0
+                head_mask[-1] = 0.0  # Mask all but the first and last heads
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids, head_mask=head_mask)
+                else:
+                    output = model(input_ids, head_mask=head_mask)
+
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = (model if isinstance(model, OpenAIGPTModel) else model.transformer).get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
+                    0)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, 0, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+                self.parent.assertEqual(
+                    len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
+                    self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
+
+        def create_and_check_openai_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
+                                                     mc_labels, lm_labels, mc_token_ids):
+            for model_class in (OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel):
+                model = model_class(config=config, keep_multihead_output=True)
+                model.eval()
+                transformer = model if isinstance(model, OpenAIGPTModel) else model.transformer
+                heads_to_prune = {0: list(range(1, self.n_head)),
+                                  -1: [0]}
+                transformer.prune_heads(heads_to_prune)
+                if isinstance(model, OpenAIGPTDoubleHeadsModel):
+                    output = model(input_ids, mc_token_ids)
+                else:
+                    output = model(input_ids)
+
+                output = sum(t.sum() for t in output[:-1])
+                output = output.sum()
+                output.backward()
+                multihead_outputs = transformer.get_multihead_outputs()
+
+                self.parent.assertEqual(len(multihead_outputs), self.n_layer)
+                self.parent.assertListEqual(
+                    list(multihead_outputs[0].size()),
+                    [self.batch_size * self.n_choices, 1,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[1].size()),
+                    [self.batch_size * self.n_choices, self.n_head,
+                     self.seq_length, self.n_embd // self.n_head])
+                self.parent.assertListEqual(
+                    list(multihead_outputs[-1].size()),
+                    [self.batch_size * self.n_choices, self.n_head-1,
+                     self.seq_length, self.n_embd // self.n_head])
+
+
     def test_default(self):
         self.run_tester(OpenAIGPTModelTest.OpenAIGPTModelTester(self))
 
@@ -220,6 +287,9 @@ class OpenAIGPTModelTest(unittest.TestCase):
         tester.check_openai_double_heads_output(output_result)
         tester.check_openai_double_heads_loss_output(output_result)
 
+        tester.create_and_check_openai_for_headmasking(*config_and_inputs)
+        tester.create_and_check_openai_for_head_pruning(*config_and_inputs)
+
     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
         """Creates a random int32 tensor of the shape within the vocab size."""