147 lines
4.5 KiB
Python
147 lines
4.5 KiB
Python
# coding=utf-8
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# Copyright (c) HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import json
|
|
import os
|
|
from collections import Counter
|
|
|
|
import torch
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
from PIL import Image
|
|
from torch import nn
|
|
from torch.utils.data import Dataset
|
|
|
|
|
|
# Maps a requested number of image embeddings N (1..9) to the 2-D output size
# given to nn.AdaptiveAvgPool2d; the two factors of each entry multiply to N.
POOLING_BREAKDOWN = {
    1: (1, 1),
    2: (2, 1),
    3: (3, 1),
    4: (2, 2),
    5: (5, 1),
    6: (3, 2),
    7: (7, 1),
    8: (4, 2),
    9: (3, 3),
}
|
|
|
|
|
|
class ImageEncoder(nn.Module):
    """Encode an image batch into a short sequence of feature vectors.

    A pretrained torchvision ResNet-152 with its last two child modules
    removed produces a feature map, which is adaptively average-pooled to the
    grid selected by ``args.num_image_embeds`` and flattened into a sequence.
    """

    def __init__(self, args):
        """Build the truncated backbone and the pooling layer.

        Args:
            args: Object exposing ``num_image_embeds``, which must be a key
                of ``POOLING_BREAKDOWN``.

        Raises:
            KeyError: If ``args.num_image_embeds`` is not in
                ``POOLING_BREAKDOWN``.
        """
        super().__init__()
        backbone = torchvision.models.resnet152(pretrained=True)
        # Keep every child except the final two (presumably the global pool
        # and the classification head -- confirm against torchvision).
        trunk = list(backbone.children())[:-2]
        self.model = nn.Sequential(*trunk)
        pooled_grid = POOLING_BREAKDOWN[args.num_image_embeds]
        self.pool = nn.AdaptiveAvgPool2d(pooled_grid)

    def forward(self, x):
        """Return pooled features with the channel axis moved last.

        Shape flow: Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048.
        """
        feature_map = self.model(x)
        pooled = self.pool(feature_map)
        # Merge the two spatial axes into one of length N.
        sequence = torch.flatten(pooled, start_dim=2)
        # Swap channel and sequence axes, then make the memory layout dense.
        return sequence.transpose(1, 2).contiguous()  # BxNx2048
|
|
|
|
|
|
class JsonlDataset(Dataset):
    """Dataset over a JSON-lines file of text / image / multi-label rows.

    Each non-blank line of ``data_path`` is parsed as one JSON object.
    ``__getitem__`` reads the keys ``"text"``, ``"label"`` (an iterable of
    label names) and ``"img"`` (an image path relative to the directory that
    contains ``data_path``).
    """

    def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
        """Load every row of the JSON-lines file into memory.

        Args:
            data_path: Path to the ``.jsonl`` file.
            tokenizer: Object with ``encode(text, add_special_tokens=True)``
                returning a sequence of token ids.
            transforms: Callable applied to the RGB ``PIL.Image`` of each row.
            labels: List of all label names; the position of a name in this
                list is its index in the multi-hot target vector.
            max_seq_length: Maximum number of text tokens kept per row (the
                first and last encoded tokens are not counted).
        """
        # Use a context manager so the file handle is closed deterministically,
        # and skip blank lines so a stray empty line does not crash json.loads.
        with open(data_path, encoding="utf-8") as jsonl_file:
            self.data = [json.loads(line) for line in jsonl_file if line.strip()]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.labels = labels
        self.n_classes = len(labels)
        self.max_seq_length = max_seq_length

        self.transforms = transforms

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the model inputs for one row.

        Returns:
            dict with keys ``"image_start_token"`` and ``"image_end_token"``
            (the first and last ids produced by the tokenizer), ``"sentence"``
            (the remaining ids, truncated to ``max_seq_length``), ``"image"``
            (output of ``self.transforms``) and ``"label"`` (multi-hot tensor
            of length ``n_classes``).

        Raises:
            ValueError: If a label name of the row is not in ``self.labels``.
        """
        row = self.data[index]

        token_ids = torch.LongTensor(self.tokenizer.encode(row["text"], add_special_tokens=True))
        # The first/last ids are split off (presumably the tokenizer's special
        # tokens) and the truncation is applied only to the ids in between.
        start_token, sentence, end_token = token_ids[0], token_ids[1:-1], token_ids[-1]
        sentence = sentence[: self.max_seq_length]

        label = torch.zeros(self.n_classes)
        label[[self.labels.index(tgt) for tgt in row["label"]]] = 1

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the context manager exits.
        with Image.open(os.path.join(self.data_dir, row["img"])) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transforms(image)

        return {
            "image_start_token": start_token,
            "image_end_token": end_token,
            "sentence": sentence,
            "image": image,
            "label": label,
        }

    def get_label_frequencies(self):
        """Return a ``Counter`` mapping each label name to its row count."""
        label_freqs = Counter()
        for row in self.data:
            label_freqs.update(row["label"])
        return label_freqs
|
|
|
|
|
|
def collate_fn(batch):
    """Collate dataset rows into padded batch tensors.

    Args:
        batch: Non-empty sequence of dicts with the keys ``"sentence"``,
            ``"image"``, ``"label"``, ``"image_start_token"`` and
            ``"image_end_token"``.

    Returns:
        Tuple ``(text, mask, image, image_start_token, image_end_token,
        label)``. ``text`` and ``mask`` are long tensors of shape
        ``(len(batch), longest sentence)``; ``text`` is zero-padded and
        ``mask`` holds 1 at real token positions and 0 at padding. The other
        four are the per-row values stacked along a new first axis.
    """
    sentences = [row["sentence"] for row in batch]
    lens = [len(tokens) for tokens in sentences]
    bsz = len(batch)
    max_seq_len = max(lens)

    mask_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)
    text_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long)

    # Copy each sentence into the left part of its row and flag those slots.
    for row_idx, (tokens, n_tokens) in enumerate(zip(sentences, lens)):
        text_tensor[row_idx, :n_tokens] = tokens
        mask_tensor[row_idx, :n_tokens] = 1

    def _stack(key):
        # Stack one field of every row along a new leading batch axis.
        return torch.stack([row[key] for row in batch])

    img_tensor = _stack("image")
    tgt_tensor = _stack("label")
    img_start_token = _stack("image_start_token")
    img_end_token = _stack("image_end_token")

    return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor
|
|
|
|
|
|
def get_mmimdb_labels():
    """Return the MM-IMDb genre label names as a new list.

    The order is fixed; callers that index labels by position rely on it.
    """
    genres = (
        "Crime",
        "Drama",
        "Thriller",
        "Action",
        "Comedy",
        "Romance",
        "Documentary",
        "Short",
        "Mystery",
        "History",
        "Family",
        "Adventure",
        "Fantasy",
        "Sci-Fi",
        "Western",
        "Horror",
        "Sport",
        "War",
        "Music",
        "Musical",
        "Animation",
        "Biography",
        "Film-Noir",
    )
    # Hand out a fresh list on every call so callers may mutate it freely.
    return list(genres)
|
|
|
|
|
|
def get_image_transforms():
    """Return the composed image preprocessing pipeline.

    Steps, in order: ``Resize(256)``, ``CenterCrop(224)``, ``ToTensor()``,
    then ``Normalize`` with the fixed per-channel mean and std listed below.
    """
    channel_mean = [0.46777044, 0.44531429, 0.40661017]
    channel_std = [0.12221994, 0.12145835, 0.14380469]

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
|