Merge pull request #2292 from patrickvonplaten/add_cached_past_for_language_generation

Add cached past for language generation
Merged by Thomas Wolf on 2019-12-27 10:33:27 +01:00 (committed via GitHub)
commit 492bea9aa0
5 changed files with 73 additions and 13 deletions


@@ -490,6 +490,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if "past" in kwargs and kwargs["past"]:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
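
The hook above is the heart of the change for CTRL (and, identically, GPT-2 below): once a cached `past` is available, only the newest token has to be fed back in. A standalone sketch of what it returns, with the function body copied from the hunk and made-up example tensors:

import torch

def prepare_inputs_for_generation(input_ids, **kwargs):
    # keep only the last token once `past` carries the cached key/value states
    if "past" in kwargs and kwargs["past"]:
        input_ids = input_ids[:, -1].unsqueeze(-1)
    inputs = {"input_ids": input_ids}
    inputs.update(kwargs)
    return inputs

input_ids = torch.tensor([[5, 7, 9, 11]])  # (batch_size=1, seq_len=4)
print(prepare_inputs_for_generation(input_ids)["input_ids"].shape)
# torch.Size([1, 4]) -- no cache yet, the full prefix goes through the model
print(prepare_inputs_for_generation(input_ids, past=("dummy",))["input_ids"].shape)
# torch.Size([1, 1]) -- with a cache, only the newest token is recomputed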


@@ -559,6 +559,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        # only last token for inputs_ids if past is defined in kwargs
+        if "past" in kwargs and kwargs["past"]:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        inputs = {"input_ids": input_ids}
+        inputs.update(kwargs)
+        return inputs
+
     def forward(
         self,
         input_ids=None,
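
Feeding just the last token is enough because the cached key/values already encode the prefix: GPT-2 produces the same next-token logits either way. A hedged sketch against the tuple-returning API of the transformers release this PR targets (prompt text is arbitrary):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

ids = tokenizer.encode("The cached past makes decoding faster", return_tensors="pt")
with torch.no_grad():
    full_logits = model(ids)[0]                      # whole sequence, no cache
    past = model(ids[:, :-1])[1]                     # run the prefix once, keep its key/values
    last_logits = model(ids[:, -1:], past=past)[0]   # then feed only the newest token

# logits for the final position agree up to numerical noise
print(torch.allclose(full_logits[:, -1], last_logits[:, -1], atol=1e-4))  # True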


@@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
             return self.out_layer
         else:
             return self.crit.out_layers[-1]
+
+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        inputs = {"input_ids": input_ids}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]
+
+        return inputs
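
Transformer-XL differs from CTRL/GPT-2 in two visible ways: the input_ids are passed through unchanged rather than cropped to the last token, and the cached state is handed over under the model-specific keyword `mems` rather than `past`. A standalone sketch of the hook above, replicated outside the class with made-up tensors:

import torch

def prepare_inputs_for_generation(input_ids, **model_kwargs):
    inputs = {"input_ids": input_ids}
    # if past is defined in model kwargs then use it for faster decoding
    if "past" in model_kwargs and model_kwargs["past"]:
        inputs["mems"] = model_kwargs["past"]
    return inputs

input_ids = torch.tensor([[17, 23, 42]])
mems = [torch.zeros(4, 1, 8)]                 # illustrative memory tensors
inputs = prepare_inputs_for_generation(input_ids, past=mems)
print(sorted(inputs))                         # ['input_ids', 'mems']
print(inputs["input_ids"].shape)              # torch.Size([1, 3]) -- sequence is not cropped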


@@ -539,6 +539,17 @@ class PreTrainedModel(nn.Module):
     def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}
 
+    def _do_output_past(self, outputs):
+        has_output_past = hasattr(self.config, "output_past") and self.config.output_past
+        has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
+
+        if has_output_past and not has_mem_len and len(outputs) > 1:
+            return True
+        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
+            return True
+
+        return False
+
     @torch.no_grad()
     def generate(
         self,
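
generate() cannot assume every architecture returns a cache, so _do_output_past inspects the config: GPT-2/CTRL advertise it via `output_past`, Transformer-XL/XLNet via a non-zero `mem_len`, and in both cases the cache is expected as the second element of the output tuple. A quick, weight-free check of those attributes (names as used in the transformers release this PR targets):

from transformers import GPT2Config, TransfoXLConfig

# GPT-2/CTRL style: `output_past` tells generate() that outputs[1] will hold `past`
print(getattr(GPT2Config(), "output_past", None))
# Transformer-XL/XLNet style: a non-zero `mem_len` means outputs[1] will hold `mems`
print(getattr(TransfoXLConfig(), "mem_len", None))
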
@@ -757,14 +768,17 @@ class PreTrainedModel(nn.Module):
         # current position / max lengths / length of generated sentences / unfinished sentences
         unfinished_sents = input_ids.new(batch_size).fill_(1)
 
-        # TODO: add cached compute states
-        pasts = None
+        past = None
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
 
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
+
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
                 for i in range(batch_size):
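
The loop above is what makes generate() faster without any caller-side change: each step now stores outputs[1] and re-encodes only the newest token. The sketch below (model, prompt and step count are arbitrary choices; 2019-era tuple outputs and the `past` keyword are assumed) mirrors that loop outside of generate() and times the cached path against the old full-prefix path:

import time
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
prompt = tokenizer.encode("Benchmarking the cached past", return_tensors="pt")

def decode(steps, use_past):
    ids, past = prompt, None
    with torch.no_grad():
        for _ in range(steps):
            if use_past and past is not None:
                logits, past = model(ids[:, -1:], past=past)[:2]  # one token per step
            else:
                logits, past = model(ids)[:2]                     # whole prefix every step
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_token], dim=-1)
    return ids

for use_past in (False, True):
    start = time.time()
    decode(steps=50, use_past=use_past)
    print("use_past=%s took %.2fs" % (use_past, time.time() - start))
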
@@ -838,15 +852,19 @@ class PreTrainedModel(nn.Module):
         beam_scores = beam_scores.view(-1)  # shape (batch_size * num_beams,)
 
         # cache compute states
-        pasts = None  # self.prepare_pasts()
+        past = None
 
         # done sentences
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
-            scores = self(**model_inputs)[0]  # (batch_size * num_beams, cur_len, vocab_size)
-            scores = scores[:, -1, :]  # (batch_size * num_beams, vocab_size)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
+            outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
+            scores = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
+
+            # if model has past, then set the past variable to speed up decoding
+            if self._do_output_past(outputs):
+                past = outputs[1]
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
@@ -935,13 +953,22 @@ class PreTrainedModel(nn.Module):
             beam_words = input_ids.new([x[1] for x in next_batch_beam])
             beam_idx = input_ids.new([x[2] for x in next_batch_beam])
 
-            # re-order batch and internal states
+            # re-order batch
             input_ids = input_ids[beam_idx, :]
             input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
 
-            # TODO: Activate cache
-            # for k in cache.keys():
-            #     if k != 'slen':
-            #         cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
+            # re-order internal states
+            if past:
+                reordered_past = []
+                for layer_past in past:
+                    # get the correct batch idx from layer past batch dim
+                    # batch dim of `past` and `mems` is at 2nd position
+                    reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
+                    reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
+                    # check that shape matches
+                    assert reordered_layer_past.shape == layer_past.shape
+                    reordered_past.append(reordered_layer_past)
+
+                past = tuple(reordered_past)
 
             # update current length
             cur_len = cur_len + 1
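
The re-ordering block is the subtle part of caching under beam search: when a surviving beam continues from a different beam of the previous step, it must also inherit that beam's cached key/values. A standalone sketch with made-up shapes, assuming (as the comment in the diff notes) that the batch/beam dimension of `past`/`mems` sits at position 1:

import torch

num_layers, num_beams, seq_len, hidden = 2, 3, 4, 5
# one cached tensor per layer; dimension 1 indexes the beams
past = tuple(torch.randn(2, num_beams, seq_len, hidden) for _ in range(num_layers))
# e.g. new beam 0 continues old beam 2, new beams 1 and 2 both continue old beam 0
beam_idx = torch.tensor([2, 0, 0])

reordered_past = []
for layer_past in past:
    # pick, for every new beam, the cache of the beam it grew out of
    reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
    reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
    assert reordered_layer_past.shape == layer_past.shape
    reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)

print(past[0].shape)  # torch.Size([2, 3, 4, 5]) -- same shape, beams permuted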


@@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         )
         target_mapping[0, 0, -1] = 1.0
 
-        return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+        inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+
+        # if past is defined in model kwargs then use it for faster decoding
+        if "past" in model_kwargs and model_kwargs["past"]:
+            inputs["mems"] = model_kwargs["past"]
+
+        return inputs
 
     def forward(
         self,
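
XLNet's override keeps its permutation-LM plumbing (the dummy token, perm_mask and target_mapping built just above this hunk) and now additionally forwards the generic `past` as `mems`. A hedged sketch with a deliberately tiny, randomly initialised model, since only tensor bookkeeping is exercised here; the small config values and the fake memory shape are arbitrary:

import torch
from transformers import XLNetConfig, XLNetLMHeadModel

config = XLNetConfig(d_model=32, n_layer=2, n_head=2, d_inner=64)
model = XLNetLMHeadModel(config)                  # random weights are enough here

input_ids = torch.tensor([[17, 23, 42]])
fake_mems = (torch.zeros(3, 1, config.d_model),)  # stands in for real cached memories

inputs = model.prepare_inputs_for_generation(input_ids, past=fake_mems)
print(sorted(inputs))              # ['input_ids', 'mems', 'perm_mask', 'target_mapping']
print(inputs["input_ids"].shape)   # torch.Size([1, 4]) -- dummy prediction token appended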