Merge pull request #2292 from patrickvonplaten/add_cached_past_for_language_generation
Add cached past for language generation
This commit is contained in:
commit
492bea9aa0
|
@ -490,6 +490,15 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
|||
def get_output_embeddings(self):
    """Return the output-embedding module (the LM head) of this model."""
    return self.lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Build the keyword arguments for ``forward`` during generation.

    If a non-empty cached ``past`` is present in ``kwargs``, only the
    last token of ``input_ids`` is kept: earlier positions are already
    represented by the cache, so re-feeding them is unnecessary.
    All remaining ``kwargs`` are forwarded unchanged.
    """
    # Truncate to the final position when a cache is available.
    if kwargs.get("past"):
        input_ids = input_ids[:, -1].unsqueeze(-1)

    model_inputs = {"input_ids": input_ids}
    model_inputs.update(kwargs)
    return model_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
|
|
|
@ -559,6 +559,15 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
|||
def get_output_embeddings(self):
    """Return the output-embedding module (the LM head) of this model."""
    return self.lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Assemble ``forward`` inputs for one generation step.

    A truthy cached ``past`` in ``kwargs`` means all but the newest
    token are already encoded in the cache, so the input is reduced to
    its last column. Every extra keyword argument is passed through.
    """
    past = kwargs.get("past")
    if past:
        # Keep only the most recent token; shape stays (batch, 1).
        input_ids = input_ids[:, -1].unsqueeze(-1)

    prepared = {"input_ids": input_ids}
    prepared.update(kwargs)
    return prepared
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
|
|
|
@ -930,3 +930,12 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||
return self.out_layer
|
||||
else:
|
||||
return self.crit.out_layers[-1]
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
    """Build ``forward`` inputs, mapping a cached ``past`` to ``mems``.

    Transfo-XL exposes its cache under the ``mems`` keyword, so a
    truthy ``past`` entry in ``model_kwargs`` is renamed accordingly;
    other keyword arguments are not forwarded.
    """
    past = model_kwargs.get("past")
    if past:
        # Reuse the cached memory states for faster decoding.
        return {"input_ids": input_ids, "mems": past}
    return {"input_ids": input_ids}
|
||||
|
|
|
@ -539,6 +539,17 @@ class PreTrainedModel(nn.Module):
|
|||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
    """Default generation-input hook: pass ``input_ids`` through.

    Subclasses override this to add model-specific inputs (cached
    states, permutation masks, ...); extra ``kwargs`` are ignored here.
    """
    return dict(input_ids=input_ids)
|
||||
|
||||
def _do_output_past(self, outputs):
    """Return True when ``outputs`` carries a reusable cache.

    The model advertises a cache either via ``config.output_past``
    (GPT-2/CTRL style) or via a positive ``config.mem_len``
    (Transfo-XL/XLNet style); in both cases the cache is only present
    when the model returned more than one output.
    """
    # A cache can only exist past the first (logits) output.
    if len(outputs) <= 1:
        return False

    mem_len = getattr(self.config, "mem_len", None)
    if mem_len:
        # Memory-based models: cache is active for positive mem_len.
        return mem_len > 0

    # Otherwise fall back to the explicit output_past flag.
    return bool(getattr(self.config, "output_past", False))
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
|
@ -757,14 +768,17 @@ class PreTrainedModel(nn.Module):
|
|||
# current position / max lengths / length of generated sentences / unfinished sentences
|
||||
unfinished_sents = input_ids.new(batch_size).fill_(1)
|
||||
|
||||
# TODO: add cached compute states
|
||||
pasts = None
|
||||
past = None
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
outputs = self(**model_inputs)
|
||||
next_token_logits = outputs[0][:, -1, :]
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
past = outputs[1]
|
||||
|
||||
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(batch_size):
|
||||
|
@ -838,15 +852,19 @@ class PreTrainedModel(nn.Module):
|
|||
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
|
||||
|
||||
# cache compute states
|
||||
pasts = None # self.prepare_pasts()
|
||||
past = None
|
||||
|
||||
# done sentences
|
||||
done = [False for _ in range(batch_size)]
|
||||
|
||||
while cur_len < max_length:
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
|
||||
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
|
||||
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past)
|
||||
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
|
||||
scores = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
|
||||
|
||||
# if model has past, then set the past variable to speed up decoding
|
||||
if self._do_output_past(outputs):
|
||||
past = outputs[1]
|
||||
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0:
|
||||
|
@ -935,13 +953,22 @@ class PreTrainedModel(nn.Module):
|
|||
beam_words = input_ids.new([x[1] for x in next_batch_beam])
|
||||
beam_idx = input_ids.new([x[2] for x in next_batch_beam])
|
||||
|
||||
# re-order batch and internal states
|
||||
# re-order batch
|
||||
input_ids = input_ids[beam_idx, :]
|
||||
input_ids = torch.cat([input_ids, beam_words.unsqueeze(1)], dim=-1)
|
||||
# TODO: Activate cache
|
||||
# for k in cache.keys():
|
||||
# if k != 'slen':
|
||||
# cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])
|
||||
|
||||
# re-order internal states
|
||||
if past:
|
||||
reordered_past = []
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from layer past batch dim
|
||||
# batch dim of `past` and `mems` is at 2nd position
|
||||
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
|
||||
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
|
||||
# check that shape matches
|
||||
assert reordered_layer_past.shape == layer_past.shape
|
||||
reordered_past.append(reordered_layer_past)
|
||||
past = tuple(reordered_past)
|
||||
|
||||
# update current length
|
||||
cur_len = cur_len + 1
|
||||
|
|
|
@ -1028,7 +1028,13 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||
)
|
||||
target_mapping[0, 0, -1] = 1.0
|
||||
|
||||
return {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
||||
inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
|
||||
|
||||
# if past is defined in model kwargs then use it for faster decoding
|
||||
if "past" in model_kwargs and model_kwargs["past"]:
|
||||
inputs["mems"] = model_kwargs["past"]
|
||||
|
||||
return inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue