From 023f51fe16e34e0ca2b5598791ae508874d5b443 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 18 Jan 2023 11:24:37 +0100 Subject: [PATCH] `blip` support for training (#21021) * `blip` support for training * remove labels creation * remove unneeded `decoder_input_ids` creation * final changes - add colab link to documentation - reduction = mean for loss * fix nits * update link * clearer error message --- docs/source/en/model_doc/blip.mdx | 4 + src/transformers/models/blip/modeling_blip.py | 55 ++- .../models/blip/modeling_blip_text.py | 2 +- tests/models/blip/test_modeling_blip.py | 318 +++++++++++++++++- 4 files changed, 364 insertions(+), 15 deletions(-) diff --git a/docs/source/en/model_doc/blip.mdx b/docs/source/en/model_doc/blip.mdx index 81f51bfd68..42116f4869 100644 --- a/docs/source/en/model_doc/blip.mdx +++ b/docs/source/en/model_doc/blip.mdx @@ -31,6 +31,10 @@ However, most existing pre-trained models only excel in either understanding-bas This model was contributed by [ybelkada](https://huggingface.co/ybelkada). The original code can be found [here](https://github.com/salesforce/BLIP). +## Resources + +- [Jupyter notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_blip.ipynb) on how to fine-tune BLIP for image captioning on a custom dataset + ## BlipConfig diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 8856fe04e8..f00c9f9cab 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -1014,6 +1014,7 @@ class BlipForConditionalGeneration(BlipPreTrainedModel): encoder_hidden_states=image_embeds, labels=labels, return_dict=return_dict, + reduction="mean", ) if not return_dict: @@ -1125,7 +1126,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): self.text_decoder = BlipTextLMHeadModel(config.text_config) self.decoder_pad_token_id = config.text_config.pad_token_id - self.decoder_bos_token_id = config.text_config.bos_token_id + self.decoder_start_token_id = config.text_config.bos_token_id # Initialize weights and apply final processing self.post_init() @@ -1133,6 +1134,19 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding + # Adapted from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right + def _shift_right(self, input_ids): + pad_token_id = self.decoder_pad_token_id + + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + @add_start_docstrings_to_model_forward(BLIP_VISION_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BlipTextVisionModelOutput, config_class=BlipVisionConfig) def forward( @@ -1168,8 +1182,14 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): >>> outputs = model(**inputs) ```""" + if labels is None and decoder_input_ids is None: + raise ValueError( + "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with" + " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you" + " are using the model for inference make sure that `decoder_input_ids` is passed." 
+ ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size = input_ids.shape[0] vision_outputs = self.vision_model( pixel_values=pixel_values, @@ -1191,11 +1211,11 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): question_embeds = question_embeds[0] if not return_dict else question_embeds.last_hidden_state - if decoder_input_ids is None: - decoder_input_ids = torch.LongTensor([self.decoder_bos_token_id]).repeat((batch_size, 1)) - - if labels is None: - labels = decoder_input_ids.masked_fill(decoder_input_ids == self.decoder_pad_token_id, -100) + if labels is not None and decoder_input_ids is None: + # get decoder inputs from shifting lm labels to the right - this is used in training mode + decoder_input_ids = self._shift_right(labels) + # replace possible -100 values in labels by `pad_token_id` + labels = labels.masked_fill(labels == self.decoder_pad_token_id, -100) answer_output = self.text_decoder( input_ids=decoder_input_ids, @@ -1204,10 +1224,13 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): encoder_attention_mask=attention_mask, labels=labels, return_dict=return_dict, - reduction="none", + reduction="mean", ) - decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean() + if labels is not None: + decoder_loss = answer_output.loss.mean() if return_dict else answer_output[0].mean() + else: + decoder_loss = None if not return_dict: outputs = (decoder_loss, image_embeds, vision_outputs[0]) + vision_outputs[2:] @@ -1288,7 +1311,7 @@ class BlipForQuestionAnswering(BlipPreTrainedModel): question_attention_mask = torch.ones(question_embeds.size()[:-1], dtype=torch.long).to(question_embeds.device) bos_ids = torch.full( - (question_embeds.size(0), 1), fill_value=self.decoder_bos_token_id, device=question_embeds.device + (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device ) outputs = self.text_decoder.generate( @@ -1330,8 +1353,16 @@ class BlipForImageTextRetrieval(BlipPreTrainedModel): # image text matching head self.itm_head = nn.Linear(config.text_config.hidden_size, 2) - self.decoder_pad_token_id = config.text_config.pad_token_id - self.decoder_bos_token_id = config.text_config.bos_token_id + self.decoder_pad_token_id = ( + config.text_config.pad_token_id + if not hasattr(config, "decoder_pad_token_id") + else config.decoder_pad_token_id + ) + self.decoder_start_token_id = ( + config.text_config.bos_token_id + if not hasattr(config, "decoder_start_token_id") + else config.decoder_start_token_id + ) # Initialize weights and apply final processing self.post_init() diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index fac1a906ef..2fd0e3c861 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -731,7 +731,7 @@ class BlipTextModel(BlipTextPreTrainedModel): past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length))) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length))).to(device) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. 
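
Reviewer note: with the modeling changes above, passing `labels` alone is enough to get a training loss out of `BlipForQuestionAnswering` — the answer ids are shifted right internally via the new `_shift_right` to build `decoder_input_ids`, pad tokens in the labels are masked to -100, and the decoder loss is now mean-reduced. A minimal training-step sketch under those changes follows; the `Salesforce/blip-vqa-base` checkpoint, the COCO image URL, and the toy question/answer pair are illustrative assumptions, not part of this patch.

```python
# Minimal single-step training sketch for BlipForQuestionAnswering after this patch.
# Checkpoint name, image URL, and the question/answer pair are placeholders.
import requests
from PIL import Image
from transformers import AutoProcessor, BlipForQuestionAnswering

processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Encoder inputs: the image and the question; labels: the tokenized answer.
inputs = processor(images=image, text="how many cats are there?", return_tensors="pt")
labels = processor(text="two", return_tensors="pt").input_ids

model.train()
outputs = model(**inputs, labels=labels)  # decoder_input_ids are built internally via _shift_right
outputs.loss.backward()                   # loss is mean-reduced (reduction="mean") after this patch
```

At inference time the expected path is still `generate` (or an explicit `decoder_input_ids`), which is exactly what the new `ValueError` above guards against when neither `labels` nor `decoder_input_ids` is supplied.
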
diff --git a/tests/models/blip/test_modeling_blip.py b/tests/models/blip/test_modeling_blip.py index 7431df7744..800bd67989 100644 --- a/tests/models/blip/test_modeling_blip.py +++ b/tests/models/blip/test_modeling_blip.py @@ -521,7 +521,7 @@ class BlipModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(model) -class BlipTextImageModelsModelTester: +class BlipTextRetrievalModelTester: def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): if text_kwargs is None: @@ -569,13 +569,319 @@ class BlipTextImageModelsModelTester: return config, inputs_dict +class BlipTextImageModelsModelTester: + def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): + + if text_kwargs is None: + text_kwargs = {} + if vision_kwargs is None: + vision_kwargs = {} + + self.parent = parent + self.text_model_tester = BlipTextModelTester(parent, **text_kwargs) + self.vision_model_tester = BlipVisionModelTester(parent, **vision_kwargs) + self.is_training = is_training + + def prepare_config_and_inputs(self): + text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs() + vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs() + + config = self.get_config() + + return config, input_ids, attention_mask, pixel_values + + def get_config(self): + return BlipConfig.from_text_vision_configs( + self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64 + ) + + def create_and_check_model(self, config, input_ids, attention_mask, pixel_values): + model = BlipModel(config).to(torch_device).eval() + with torch.no_grad(): + result = model(input_ids, pixel_values, attention_mask) + self.parent.assertEqual( + result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size) + ) + self.parent.assertEqual( + result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size) + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, input_ids, attention_mask, pixel_values = config_and_inputs + inputs_dict = { + "input_ids": input_ids, + "labels": input_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + } + return config, inputs_dict + + +@require_torch +@require_vision +class BlipVQAModelTest(unittest.TestCase): + all_model_classes = (BlipForQuestionAnswering,) if is_torch_available() else () + + def setUp(self): + self.model_tester = BlipModelTester(self) + + def _prepare_inputs_for_vqa(self): + _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict["labels"] = inputs_dict["input_ids"] + inputs_dict.pop("return_loss") + return inputs_dict + + def test_class_name_consistency(self): + """ + Tests that all VQA models have a class name that ends with "ForQuestionAnswering" + """ + for model_class in self.all_model_classes: + model = model_class(self.model_tester.get_config()) + self.assertTrue( + model.__class__.__name__.endswith("ForQuestionAnswering"), + f"Class name should end with 'ForVisualQuestionAnswering' got {model.__class__.__name__}", + ) + + def test_training(self): + """ + Tests that all VQA models can be trained on a single batch + """ + for model_class in self.all_model_classes: + model = model_class(self.model_tester.get_config()).to(torch_device) + model.train() + loss = model(**self._prepare_inputs_for_vqa()).loss + loss.backward() + + # verify the gradients are not 
None + for name, param in model.named_parameters(): + self.assertIsNotNone(param.grad, f"Gradients should not be None - got {param.grad} for {name}") + + def test_forward_signature(self): + """ + Test if the forward function has the expected arguments. + """ + for model_class in self.all_model_classes: + model = model_class(self.model_tester.get_config()) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so args are the first n entries + args = list(signature.parameters.keys()) + expected_args = [ + "input_ids", + "attention_mask", + "labels", + "decoder_input_ids", + "decoder_attention_mask", + ] + for arg in expected_args: + self.assertTrue( + arg in args, + f"Argument {arg} of forward function signature should include {arg}. Found {args}.", + ) + + +@require_torch +class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (BlipForImageTextRetrieval,) if is_torch_available() else () + fx_compatible = False + test_head_masking = False + test_pruning = False + test_resize_embeddings = False + test_attention_outputs = False + test_torchscript = False + + def setUp(self): + self.model_tester = BlipTextRetrievalModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="Hidden_states is tested in individual model tests") + def test_hidden_states_output(self): + pass + + @unittest.skip(reason="Inputs_embeds is tested in individual model tests") + def test_inputs_embeds(self): + pass + + @unittest.skip(reason="Retain_grad is tested in individual model tests") + def test_retain_grad_hidden_states_attentions(self): + pass + + @unittest.skip(reason="BlipModel does not have input/output embeddings") + def test_model_common_attributes(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + if model.config.is_encoder_decoder: + expected_arg_names = [ + "input_ids", + "attention_mask", + "decoder_input_ids", + "decoder_attention_mask", + ] + expected_arg_names.extend( + ["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] + if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names + else ["encoder_outputs"] + ) + self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) + else: + expected_arg_names = ["input_ids"] if model_class != BlipForConditionalGeneration else ["pixel_values"] + self.assertListEqual(arg_names[:1], expected_arg_names) + + def test_training(self): + if not self.model_tester.is_training: + return + + for model_class in self.all_model_classes[:-1]: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + model = model_class(config) + model.to(torch_device) + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + # hardcode labels to be the same as input_ids + inputs["labels"] = inputs["input_ids"] + + loss = model(**inputs).loss + loss.backward() + + def test_training_gradient_checkpointing(self): + if not self.model_tester.is_training: + return + + for model_class in 
self.all_model_classes[:-1]: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.use_cache = False + config.return_dict = True + + model = model_class(config) + model.to(torch_device) + model.gradient_checkpointing_enable() + model.train() + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + # hardcode labels to be the same as input_ids + inputs["labels"] = inputs["input_ids"] + + loss = model(**inputs).loss + loss.backward() + + # override as the `logit_scale` parameter initilization is different for Blip + def test_initialization(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + configs_no_init = _config_zero_init(config) + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + for name, param in model.named_parameters(): + if param.requires_grad: + # check if `logit_scale` is initilized as per the original implementation + if name == "logit_scale": + self.assertAlmostEqual( + param.data.item(), + np.log(1 / 0.07), + delta=1e-3, + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + else: + self.assertIn( + ((param.data.mean() * 1e9).round() / 1e9).item(), + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def _create_and_check_torchscript(self, config, inputs_dict): + if not self.test_torchscript: + return + + configs_no_init = _config_zero_init(config) # To be sure we have no Nan + configs_no_init.torchscript = True + configs_no_init.return_dict = False + for model_class in self.all_model_classes: + model = model_class(config=configs_no_init) + model.to(torch_device) + model.eval() + + try: + input_ids = inputs_dict["input_ids"] + pixel_values = inputs_dict["pixel_values"] # Blip needs pixel_values + traced_model = torch.jit.trace(model, (input_ids, pixel_values)) + except RuntimeError: + self.fail("Couldn't trace module.") + + with tempfile.TemporaryDirectory() as tmp_dir_name: + pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt") + + try: + torch.jit.save(traced_model, pt_file_name) + except Exception: + self.fail("Couldn't save module.") + + try: + loaded_model = torch.jit.load(pt_file_name) + except Exception: + self.fail("Couldn't load module.") + + model.to(torch_device) + model.eval() + + loaded_model.to(torch_device) + loaded_model.eval() + + model_state_dict = model.state_dict() + loaded_model_state_dict = loaded_model.state_dict() + + self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys())) + + models_equal = True + for layer_name, p1 in model_state_dict.items(): + p2 = loaded_model_state_dict[layer_name] + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + def test_load_vision_text_config(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Save BlipConfig and check if we can load BlipVisionConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + vision_config = BlipVisionConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict()) + + # Save BlipConfig and check if we can load BlipTextConfig from it + with tempfile.TemporaryDirectory() as tmp_dir_name: + config.save_pretrained(tmp_dir_name) + text_config = BlipTextConfig.from_pretrained(tmp_dir_name) + self.assertDictEqual(config.text_config.to_dict(), 
text_config.to_dict()) + + @slow + def test_model_from_pretrained(self): + for model_name in BLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = BlipModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + @require_torch class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( ( BlipForConditionalGeneration, BlipForQuestionAnswering, - BlipForImageTextRetrieval, ) if is_torch_available() else () @@ -648,6 +954,10 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase): model.to(torch_device) model.train() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + # hardcode labels to be the same as input_ids + inputs["labels"] = inputs["input_ids"] + loss = model(**inputs).loss loss.backward() @@ -665,6 +975,10 @@ class BlipTextImageModelTest(ModelTesterMixin, unittest.TestCase): model.gradient_checkpointing_enable() model.train() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + + # hardcode labels to be the same as input_ids + inputs["labels"] = inputs["input_ids"] + loss = model(**inputs).loss loss.backward()
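
The new training tests above simply reuse `input_ids` as `labels`, and the captioning notebook linked in the docs follows the same pattern. A minimal single-step captioning sketch with `BlipForConditionalGeneration` is shown below; the checkpoint name, image URL, caption text, and optimizer settings are illustrative assumptions rather than part of the patch.

```python
# Minimal captioning fine-tuning step with BlipForConditionalGeneration,
# mirroring the labels == input_ids pattern used by the new training tests.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
caption = "two cats sleeping on a couch"

model.train()
inputs = processor(images=image, text=caption, return_tensors="pt")
outputs = model(**inputs, labels=inputs.input_ids)  # caption token ids double as labels
outputs.loss.backward()                             # single mean-reduced loss after this patch
optimizer.step()
optimizer.zero_grad()
```

In a real fine-tuning run this step would sit inside a `DataLoader` loop over (image, caption) pairs, as shown end-to-end in the notebook referenced in the Resources section.
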