Patch-past-refactor (#21050)

* small patches, forgot a line
* refactor PT
* the actual fix

parent 48d4e147d8
commit e3ecbaa4ab
@@ -729,13 +729,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
         )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
         past_key_values = decoder_inputs.get("past_key_values")
-        if past_key_values is None:
-            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "pixel_values": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,
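To see why the TF fallback disappears, here is a minimal runnable sketch of the delegation pattern, assuming toy stand-in classes (ToyDecoder and ToyEncoderDecoder are hypothetical, not transformers code): once every decoder's prepare_inputs_for_generation accepts the cache under the name past_key_values, the composite model can pass it straight through, and the decoder_inputs.get("past") fallback that existed for TF GPT-2 becomes dead code.

import numpy as np

class ToyDecoder:
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]  # cache present: only the newest token is needed
        return {"input_ids": input_ids, "past_key_values": past_key_values}

class ToyEncoderDecoder:
    def __init__(self, decoder):
        self.decoder = decoder

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # The cache travels under one canonical name, so no
        # decoder_inputs.get("past") fallback is required.
        decoder_inputs = self.decoder.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values
        )
        return {
            "decoder_input_ids": decoder_inputs["input_ids"],
            "past_key_values": decoder_inputs.get("past_key_values"),
        }

ids = np.array([[5, 6, 7]])
print(ToyEncoderDecoder(ToyDecoder()).prepare_inputs_for_generation(ids, past_key_values=("k", "v")))
# {'decoder_input_ids': array([[7]]), 'past_key_values': ('k', 'v')}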
@@ -649,9 +649,9 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
         input_dict = {
             "attention_mask": attention_mask,
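For orientation, this is roughly the loop inside generate() that calls the method above on every step (a simplified greedy-decoding sketch, assuming a standard transformers-style model whose forward returns logits and past_key_values; the real implementation also handles beam search, encoder outputs, and stopping criteria). It shows why the keyword has to match: the caller hands the cache back under the name past_key_values, so a signature still spelling it past would silently route the cache into **kwargs and lose it.

import torch

def toy_greedy_loop(model, input_ids, max_new_tokens=5, **model_kwargs):
    past_key_values = None  # no cache on the first step
    for _ in range(max_new_tokens):
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, use_cache=True, **model_kwargs
        )
        outputs = model(**model_inputs)
        # Greedy choice: take the highest-scoring next token.
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        past_key_values = outputs.past_key_values  # fed back on the next step
    return input_ids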
@@ -3333,7 +3333,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
 
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
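Why the template keeps only the last token when a cache is present: the keys and values for all earlier positions already live in past_key_values, so only the newest token needs a forward pass. A tiny NumPy illustration (the token values and the cache stand-in are hypothetical):

import numpy as np

input_ids = np.array([[101, 7592, 2088]])  # three tokens decoded so far
past_key_values = object()                 # stand-in for a non-empty cache

if past_key_values:
    # Earlier positions are cached; feed only the newest token.
    input_ids = input_ids[:, -1:]

print(input_ids)  # [[2088]]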
@@ -55,7 +55,6 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
         )
 
     @require_tf
-    @unittest.skip("Arthur will fix me!")
     def test_small_model_tf(self):
         pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2", framework="tf")
         image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
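With the fix in place, the previously skipped TF test runs again. A usage sketch along the same lines as the test (tiny testing checkpoint; the image path assumes a transformers checkout as the working directory):

from transformers import pipeline

pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2", framework="tf")
outputs = pipe("./tests/fixtures/tests_samples/COCO/000000039769.png")
print(outputs)  # e.g. [{'generated_text': '...'}]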