Patch-past-refactor (#21050)

* small patches, forgot a line

* refactor PT

* the actual fix
Arthur 2023-01-09 18:12:13 +01:00 committed by GitHub
parent 48d4e147d8
commit e3ecbaa4ab
4 changed files with 5 additions and 8 deletions
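
The refactor renames the "past" argument of prepare_inputs_for_generation to "past_key_values". A minimal sketch of the failure mode being patched (illustrative names, not the actual generate() internals): the generation loop forwards the cache under the past_key_values keyword, so an override that still declares only past silently swallows the cache through **kwargs and never truncates the input ids.

def prepare_inputs_old(input_ids, past=None, **kwargs):
    # the cache arrives as kwargs["past_key_values"], so `past` stays None
    # and the full sequence is re-fed instead of only the last token
    return {"input_ids": input_ids[-1:] if past is not None else input_ids}

def prepare_inputs_new(input_ids, past_key_values=None, **kwargs):
    return {"input_ids": input_ids[-1:] if past_key_values is not None else input_ids}

cache = object()  # stands in for the real key/value tensors
print(prepare_inputs_old([1, 2, 3], past_key_values=cache))  # {'input_ids': [1, 2, 3]} -- cache dropped
print(prepare_inputs_new([1, 2, 3], past_key_values=cache))  # {'input_ids': [3]}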


@@ -729,13 +729,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
     )
 
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
-        past_key_values = decoder_inputs.get("past_key_values")
-        if past_key_values is None:
-            past_key_values = decoder_inputs.get("past")  # e.g. on TF GPT2
         input_dict = {
             "pixel_values": None,  # needs to be passed to make Keras.layer.__call__ happy
             "attention_mask": attention_mask,


@@ -649,9 +649,9 @@ class VisionEncoderDecoderModel(PreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     def prepare_inputs_for_generation(
-        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
     ):
-        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
+        decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values)
         decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
         input_dict = {
             "attention_mask": attention_mask,


@@ -3333,7 +3333,7 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         if attention_mask is None:
             attention_mask = input_ids.new_ones(input_ids.shape)
 
-        if past:
+        if past_key_values:
             input_ids = input_ids[:, -1:]
         # first step, decoder_cached_states are empty
         return {
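
The template change keeps the truncation logic and only renames its guard. A framework-free toy (plain lists standing in for tensors) of what the guard does:

def next_step_inputs(input_ids, past_key_values=None):
    # once a cache exists, only the newest position has to be fed to the model
    if past_key_values:
        input_ids = [row[-1:] for row in input_ids]
    return input_ids

print(next_step_inputs([[5, 7, 9]]))                          # [[5, 7, 9]]  first step, no cache
print(next_step_inputs([[5, 7, 9]], past_key_values=["kv"]))  # [[9]]        subsequent steps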


@@ -55,7 +55,6 @@ class ImageToTextPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
     )
 
     @require_tf
-    @unittest.skip("Arthur will fix me!")
     def test_small_model_tf(self):
         pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-vit-gpt2", framework="tf")
         image = "./tests/fixtures/tests_samples/COCO/000000039769.png"