Add input_embeds functionality to gpt_neo Causal LM (#25659)
* Updated the gpt_neo causal LM to support using input embeddings for generation
* Added indentation
* Ran make fixup
parent 908f853688
commit 977b2f05d5
src/transformers/models/gpt_neo/modeling_gpt_neo.py

@@ -680,7 +680,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past_key_values:
@@ -698,14 +698,24 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -1].unsqueeze(-1)

-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "input_ids": input_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+
+        return model_inputs

     @add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
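For context, a minimal usage sketch of what this change enables, assuming a transformers release that contains it; the EleutherAI/gpt-neo-125m checkpoint and the prompt below are illustrative, not part of the commit:

from transformers import AutoTokenizer, GPTNeoForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")

prompt_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids
# Any float tensor of shape (batch, seq_len, hidden_size) works here, e.g. a
# learned soft prompt; reusing the model's own embedding layer keeps the
# sketch self-contained.
inputs_embeds = model.get_input_embeddings()(prompt_ids)

# The embeddings are consumed only on the first forward pass; once
# past_key_values exists, decoding falls back to the sampled token ids.
output_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=20)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))

Restricting inputs_embeds to the step where past_key_values is still empty confines the change to prompt encoding: every later decoding step feeds back token ids exactly as before.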