750 lines
28 KiB
Python
750 lines
28 KiB
Python
# coding=utf-8
|
|
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Testing suite for the PyTorch GroupViT model."""
|
|
|
|
import inspect
|
|
import os
|
|
import random
|
|
import tempfile
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import requests
|
|
|
|
from transformers import GroupViTConfig, GroupViTTextConfig, GroupViTVisionConfig
|
|
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
|
|
from transformers.utils import is_torch_available, is_vision_available
|
|
|
|
from ...test_configuration_common import ConfigTester
|
|
from ...test_modeling_common import (
|
|
ModelTesterMixin,
|
|
_config_zero_init,
|
|
floats_tensor,
|
|
ids_tensor,
|
|
random_attention_mask,
|
|
)
|
|
from ...test_pipeline_mixin import PipelineTesterMixin
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
from torch import nn
|
|
|
|
from transformers import GroupViTModel, GroupViTTextModel, GroupViTVisionModel
|
|
|
|
|
|
if is_vision_available():
|
|
from PIL import Image
|
|
|
|
from transformers import CLIPProcessor
|
|
|
|
|
|
class GroupViTVisionModelTester:
    """Builds a small GroupViT vision config plus random pixel inputs for the vision model tests."""

    def __init__(
        self,
        parent,
        batch_size=12,
        image_size=30,
        patch_size=2,
        num_channels=3,
        is_training=True,
        hidden_size=32,
        depths=None,
        num_group_tokens=None,
        num_output_groups=None,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        initializer_range=0.02,
        scope=None,
    ):
        """Store the hyper-parameters used to build configs and inputs.

        Args:
            parent: the ``unittest.TestCase`` whose assertion methods are used.
            depths / num_group_tokens / num_output_groups: per-stage settings. When omitted they
                default to ``[6, 3, 3]`` / ``[64, 8, 0]`` / ``[64, 8, 8]``. A fresh list is built on
                every call so instances never share (and cannot mutate) a common default.
        """
        if depths is None:
            depths = [6, 3, 3]
        if num_group_tokens is None:
            num_group_tokens = [64, 8, 0]
        if num_output_groups is None:
            num_output_groups = [64, 8, 8]

        self.parent = parent
        self.batch_size = batch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.is_training = is_training
        self.hidden_size = hidden_size
        self.depths = depths
        self.num_hidden_layers = sum(depths)
        # number of hidden states the common tests expect: one per stage plus one extra
        # (presumably the embedding output -- confirm against the model implementation)
        self.expected_num_hidden_layers = len(depths) + 1
        self.num_group_tokens = num_group_tokens
        self.num_output_groups = num_output_groups
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range
        self.scope = scope

        num_patches = (image_size // patch_size) ** 2
        # no [CLS] token for GroupViT
        self.seq_length = num_patches

    def prepare_config_and_inputs(self):
        """Return ``(config, pixel_values)`` with pixel values drawn from a fixed-seed RNG."""
        # A fixed seed keeps the pixel values identical across runs.
        rng = random.Random(0)
        pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size], rng=rng)
        config = self.get_config()

        return config, pixel_values

    def get_config(self):
        """Build a ``GroupViTVisionConfig`` from the stored hyper-parameters."""
        return GroupViTVisionConfig(
            image_size=self.image_size,
            patch_size=self.patch_size,
            num_channels=self.num_channels,
            hidden_size=self.hidden_size,
            depths=self.depths,
            num_group_tokens=self.num_group_tokens,
            num_output_groups=self.num_output_groups,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, pixel_values):
        """Run the vision model and check the shapes of its last hidden state and pooled output."""
        model = GroupViTVisionModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            result = model(pixel_values)
        # the sequence dimension of the last hidden state is the last stage's number of output groups
        self.parent.assertEqual(
            result.last_hidden_state.shape, (self.batch_size, self.num_output_groups[-1], self.hidden_size)
        )
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the shape expected by ``ModelTesterMixin``."""
        config_and_inputs = self.prepare_config_and_inputs()
        config, pixel_values = config_and_inputs
        inputs_dict = {"pixel_values": pixel_values}
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class GroupViTVisionModelTest(ModelTesterMixin, unittest.TestCase):
    """
    Here we also overwrite some of the tests of test_modeling_common.py, as GROUPVIT does not use input_ids, inputs_embeds,
    attention_mask and seq_length.
    """

    all_model_classes = (GroupViTVisionModel,) if is_torch_available() else ()

    test_pruning = False
    test_torchscript = False
    test_resize_embeddings = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = GroupViTVisionModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=GroupViTVisionConfig, has_text_modality=False, hidden_size=37
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip(reason="GroupViT does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
        """Run the common PT/TF equivalence test under a fixed seed for every RNG involved."""
        import tensorflow as tf

        seed = 338
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        tf.random.set_seed(seed)
        return super().test_pt_tf_model_equivalence()

    def test_model_common_attributes(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), nn.Module)
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, nn.Linear))

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            expected_arg_names = ["pixel_values"]
            self.assertListEqual(arg_names[:1], expected_arg_names)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_attention_outputs(self):
        """Check the number and shape of the grouping attentions GroupViT returns."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        seq_len = getattr(self.model_tester, "seq_length", None)

        # one attention output is expected per stage whose `num_group_tokens` is positive
        expected_num_attention_outputs = sum(g > 0 for g in self.model_tester.num_group_tokens)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            # GroupViT returns attention grouping of each stage
            self.assertEqual(len(attentions), expected_num_attention_outputs)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            # GroupViT returns attention grouping of each stage
            self.assertEqual(len(attentions), expected_num_attention_outputs)

            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            # enabling hidden states adds exactly one extra output field
            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            # GroupViT returns attention grouping of each stage
            self.assertEqual(len(self_attentions), expected_num_attention_outputs)
            for i, self_attn in enumerate(self_attentions):
                if self_attn is None:
                    continue

                # stage i maps the previous stage's groups (or the patch sequence for the
                # first stage) onto its own output groups
                self.assertListEqual(
                    list(self_attn.shape[-2:]),
                    [
                        self.model_tester.num_output_groups[i],
                        self.model_tester.num_output_groups[i - 1] if i > 0 else seq_len,
                    ],
                )

    @unittest.skip(reason="Training is not tested for GroupViTVisionModel")
    def test_training(self):
        pass

    @unittest.skip(reason="Training is not tested for GroupViTVisionModel")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="GroupViTVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="GroupViTVisionModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    # override since the attention mask from GroupViT is not used to compute loss, thus no grad
    def test_retain_grad_hidden_states_attentions(self):
        """Hidden states must receive a gradient; the grouping attention must not."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
        config.output_attentions = self.has_attentions

        # no need to test all models as different heads yield the same functionality
        model_class = self.all_model_classes[0]
        model = model_class(config)
        model.to(torch_device)

        inputs = self._prepare_for_class(inputs_dict, model_class)
        outputs = model(**inputs)
        output = outputs[0]

        # The vision model's outputs expose plain `hidden_states` / `attentions` (no encoder/decoder
        # split), so only the encoder-/decoder-only path of the common test applies here.
        hidden_states = outputs.hidden_states[0]
        hidden_states.retain_grad()

        if self.has_attentions:
            attentions = outputs.attentions[0]
            attentions.retain_grad()

        output.flatten()[0].backward(retain_graph=True)

        self.assertIsNotNone(hidden_states.grad)

        if self.has_attentions:
            self.assertIsNone(attentions.grad)

    @slow
    def test_model_from_pretrained(self):
        model_name = "nvidia/groupvit-gcc-yfcc"
        model = GroupViTVisionModel.from_pretrained(model_name)
        self.assertIsNotNone(model)
|
|
|
|
|
|
class GroupViTTextModelTester:
    """Builds a small GroupViT text config plus token ids and a left-aligned attention mask."""

    def __init__(
        self,
        parent,
        batch_size=12,
        seq_length=7,
        is_training=True,
        use_input_mask=True,
        use_labels=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=37,
        dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.scope = scope

    @staticmethod
    def _left_aligned_start_indices(batch_size, seq_length):
        """Draw, per row, the index at which padding starts in a left-aligned attention mask.

        Uses the global NumPy RNG, so the draws follow ``np.random.seed``. For ``seq_length >= 3``
        the indices lie in ``[1, seq_length - 2]``. For ``seq_length <= 2`` the range collapses to
        ``1`` instead of making ``np.random.randint`` raise (``low >= high``).
        """
        return np.random.randint(1, max(seq_length - 1, 2), size=(batch_size,))

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, input_mask)``; ``input_mask`` is ``None`` when masks are disabled."""
        rng = random.Random(0)
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size, rng=rng)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])
            # Overwrite every row with a left-aligned mask: ones before a random start index,
            # zeros from it onwards.
            batch_size, seq_length = input_mask.shape
            start_indices = self._left_aligned_start_indices(batch_size, seq_length)
            for batch_idx, start_index in enumerate(start_indices):
                input_mask[batch_idx, :start_index] = 1
                input_mask[batch_idx, start_index:] = 0

        config = self.get_config()

        return config, input_ids, input_mask

    def get_config(self):
        """Build a ``GroupViTTextConfig`` from the stored hyper-parameters."""
        return GroupViTTextConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            dropout=self.dropout,
            attention_dropout=self.attention_dropout,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
        )

    def create_and_check_model(self, config, input_ids, input_mask):
        """Run the text model with and without the attention mask and check both output shapes."""
        model = GroupViTTextModel(config=config)
        model.to(torch_device)
        model.eval()
        with torch.no_grad():
            masked_result = model(input_ids, attention_mask=input_mask)
            unmasked_result = model(input_ids)

        expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
        expected_pooled_shape = (self.batch_size, self.hidden_size)
        # both forward paths are checked, so the masked call is actually validated
        for result in (masked_result, unmasked_result):
            self.parent.assertEqual(result.last_hidden_state.shape, expected_hidden_shape)
            self.parent.assertEqual(result.pooler_output.shape, expected_pooled_shape)

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the shape expected by ``ModelTesterMixin``."""
        config, input_ids, input_mask = self.prepare_config_and_inputs()
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class GroupViTTextModelTest(ModelTesterMixin, unittest.TestCase):
    """Common model tests for the stand-alone GroupViT text encoder."""

    all_model_classes = (GroupViTTextModel,) if is_torch_available() else ()
    test_pruning = False
    test_head_masking = False

    def setUp(self):
        self.model_tester = GroupViTTextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=GroupViTTextConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    # Report these as skipped rather than letting an empty body count as a pass.
    @unittest.skip(reason="Training is not tested for GroupViTTextModel")
    def test_training(self):
        pass

    @unittest.skip(reason="Training is not tested for GroupViTTextModel")
    def test_training_gradient_checkpointing(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant(self):
        pass

    @unittest.skip(
        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
    )
    def test_training_gradient_checkpointing_use_reentrant_false(self):
        pass

    @unittest.skip(reason="GroupViTTextModel does not use inputs_embeds")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="GroupViTTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass

    @unittest.skip(reason="GroupViTTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_to_base(self):
        pass

    @slow
    def test_model_from_pretrained(self):
        model_name = "nvidia/groupvit-gcc-yfcc"
        model = GroupViTTextModel.from_pretrained(model_name)
        self.assertIsNotNone(model)
|
|
|
|
|
|
class GroupViTModelTester:
    """Combines the text and vision testers to build configs and inputs for the joint GroupViTModel."""

    def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
        """
        Args:
            parent: the ``unittest.TestCase`` whose assertion methods are used.
            text_kwargs / vision_kwargs: optional overrides forwarded to the text / vision testers.
        """
        if text_kwargs is None:
            text_kwargs = {}
        if vision_kwargs is None:
            vision_kwargs = {}

        self.parent = parent
        self.text_model_tester = GroupViTTextModelTester(parent, **text_kwargs)
        self.vision_model_tester = GroupViTVisionModelTester(parent, **vision_kwargs)
        self.batch_size = self.text_model_tester.batch_size  # need bs for batching_equivalence test
        self.is_training = is_training

    def prepare_config_and_inputs(self):
        """Return ``(config, input_ids, attention_mask, pixel_values)`` for the joint model."""
        # the per-modality configs are discarded; the joint config is built by `get_config`
        _, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
        _, pixel_values = self.vision_model_tester.prepare_config_and_inputs()

        config = self.get_config()

        return config, input_ids, attention_mask, pixel_values

    def get_config(self):
        """Build a ``GroupViTConfig`` from the text and vision sub-configs with ``projection_dim=64``."""
        return GroupViTConfig.from_text_vision_configs(
            self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
        )

    def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
        """Run the joint model and check the shapes of the image-text similarity logits."""
        model = GroupViTModel(config).to(torch_device).eval()
        with torch.no_grad():
            # pass by keyword (the same names the common tests use) rather than relying on the
            # positional order of `forward`
            result = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        self.parent.assertEqual(
            result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
        )
        self.parent.assertEqual(
            result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
        )

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` for ``ModelTesterMixin``; the loss is requested too."""
        config, input_ids, attention_mask, pixel_values = self.prepare_config_and_inputs()
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "return_loss": True,
        }
        return config, inputs_dict
|
|
|
|
|
|
@require_torch
class GroupViTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common and pipeline tests for the joint text-image GroupViTModel."""

    all_model_classes = (GroupViTModel,) if is_torch_available() else ()
    pipeline_model_mapping = {"feature-extraction": GroupViTModel} if is_torch_available() else {}
    test_head_masking = False
    test_pruning = False
    test_resize_embeddings = False
    test_attention_outputs = False

    def setUp(self):
        self.model_tester = GroupViTModelTester(self)

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip(reason="hidden_states are tested in individual model tests")
    def test_hidden_states_output(self):
        pass

    @unittest.skip(reason="input_embeds are tested in individual model tests")
    def test_inputs_embeds(self):
        pass

    @unittest.skip(reason="tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass

    @unittest.skip(reason="GroupViTModel does not have input/output embeddings")
    def test_model_common_attributes(self):
        pass

    # overwritten from parent so the PT/TF output comparison uses a default tolerance of 2e-5
    def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-5, name="outputs", attributes=None):
        super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=tol, name=name, attributes=attributes)

    # overwritten from parent as this equivalence test needs a specific `seed`, and a good one is hard to find
    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
        import tensorflow as tf

        seed = 163
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        tf.random.set_seed(seed)
        return super().test_pt_tf_model_equivalence()

    # override as the `logit_scale` parameter initialization is different for GROUPVIT
    def test_initialization(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                if name == "logit_scale":
                    # `logit_scale` is expected to start at log(1 / 0.07), as in the original implementation
                    self.assertAlmostEqual(
                        param.data.item(),
                        np.log(1 / 0.07),
                        delta=1e-3,
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )
                else:
                    # every other trainable parameter must have a mean of (rounded) 0 or 1
                    self.assertIn(
                        ((param.data.mean() * 1e9).round() / 1e9).item(),
                        [0.0, 1.0],
                        msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                    )

    def _create_and_check_torchscript(self, config, inputs_dict):
        """Trace the model with (input_ids, pixel_values), round-trip it through disk, and compare state."""
        if not self.test_torchscript:
            return

        configs_no_init = _config_zero_init(config)  # To be sure we have no Nan
        configs_no_init.torchscript = True
        configs_no_init.return_dict = False
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            model.to(torch_device)
            model.eval()

            try:
                input_ids = inputs_dict["input_ids"]
                pixel_values = inputs_dict["pixel_values"]  # GROUPVIT needs pixel_values
                traced_model = torch.jit.trace(model, (input_ids, pixel_values))
            except RuntimeError as err:
                self.fail(f"Couldn't trace module: {err}")

            with tempfile.TemporaryDirectory() as tmp_dir_name:
                pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                try:
                    torch.jit.save(traced_model, pt_file_name)
                except Exception as err:
                    self.fail(f"Couldn't save module: {err}")

                try:
                    loaded_model = torch.jit.load(pt_file_name)
                except Exception as err:
                    self.fail(f"Couldn't load module: {err}")

            loaded_model.to(torch_device)
            loaded_model.eval()

            model_state_dict = model.state_dict()
            loaded_model_state_dict = loaded_model.state_dict()

            # Keys present only in the traced module's state dict are treated as non-persistent
            # buffers and compared separately against the eager model's buffers below.
            non_persistent_buffers = {
                key: value for key, value in loaded_model_state_dict.items() if key not in model_state_dict
            }
            loaded_model_state_dict = {
                key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
            }

            self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))

            # Each such buffer must equal a distinct buffer of the eager model.
            model_buffers = list(model.buffers())
            for non_persistent_buffer in non_persistent_buffers.values():
                match_index = next(
                    (
                        i
                        for i, model_buffer in enumerate(model_buffers)
                        if torch.equal(non_persistent_buffer, model_buffer)
                    ),
                    None,
                )
                self.assertIsNotNone(match_index)
                model_buffers.pop(match_index)

            models_equal = all(
                p1.data.ne(loaded_model_state_dict[layer_name].data).sum() == 0
                for layer_name, p1 in model_state_dict.items()
            )

            self.assertTrue(models_equal)

    def test_load_vision_text_config(self):
        """A saved GroupViTConfig must be loadable as its vision and text sub-configs."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        # Save GroupViTConfig and check if we can load GroupViTVisionConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            vision_config = GroupViTVisionConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())

        # Save GroupViTConfig and check if we can load GroupViTTextConfig from it
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            config.save_pretrained(tmp_dir_name)
            text_config = GroupViTTextConfig.from_pretrained(tmp_dir_name)
            self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())

    @slow
    def test_model_from_pretrained(self):
        model_name = "nvidia/groupvit-gcc-yfcc"
        model = GroupViTModel.from_pretrained(model_name)
        self.assertIsNotNone(model)
|
|
|
|
|
|
# We will verify our results on an image of cute cats
def prepare_img():
    """Download the COCO val2017 image used by the integration test and open it with PIL.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within the timeout.
    """
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    # A timeout keeps a stalled download from hanging the whole test run, and
    # raise_for_status reports HTTP errors directly instead of as a PIL decode failure.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    # The response is deliberately left open: PIL reads pixel data lazily from the raw stream.
    im = Image.open(response.raw)
    return im
|
|
|
|
|
|
@require_vision
@require_torch
class GroupViTModelIntegrationTest(unittest.TestCase):
    """Slow end-to-end check of the pretrained GroupViT checkpoint on a real image."""

    @slow
    def test_inference(self):
        checkpoint = "nvidia/groupvit-gcc-yfcc"
        model = GroupViTModel.from_pretrained(checkpoint)
        processor = CLIPProcessor.from_pretrained(checkpoint)

        captions = ["a photo of a cat", "a photo of a dog"]
        inputs = processor(text=captions, images=prepare_img(), padding=True, return_tensors="pt")

        # forward pass without gradient tracking
        with torch.no_grad():
            outputs = model(**inputs)

        n_images = inputs.pixel_values.shape[0]
        n_texts = inputs.input_ids.shape[0]

        # image->text logits are (images, texts); text->image logits are the transposed shape
        self.assertEqual(outputs.logits_per_image.shape, torch.Size((n_images, n_texts)))
        self.assertEqual(outputs.logits_per_text.shape, torch.Size((n_texts, n_images)))

        # reference logits for the single image against the two captions
        expected_logits = torch.tensor([[13.3523, 6.3629]])
        logits_match = torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3)
        self.assertTrue(logits_match)
|