51 lines
1.7 KiB
Python
51 lines
1.7 KiB
Python
import glob
|
|
import random
|
|
import os
|
|
import numpy as np
|
|
|
|
from jittor.dataset.dataset import Dataset
|
|
import jittor.transform as transform
|
|
from PIL import Image
|
|
|
|
|
|
class ImageDataset(Dataset):
|
|
def __init__(self, root, transforms_=None, img_size=128, mask_size=64, mode="train"):
|
|
super().__init__()
|
|
self.transform = transform.Compose(transforms_)
|
|
self.img_size = img_size
|
|
self.mask_size = mask_size
|
|
self.mode = mode
|
|
self.files = sorted(glob.glob("%s/*.jpg" % root))
|
|
self.files = self.files[:-4000] if mode == "train" else self.files[-4000:]
|
|
self.set_attrs(total_len=len(self.files))
|
|
|
|
def apply_random_mask(self, img):
|
|
"""Randomly masks image"""
|
|
y1, x1 = np.random.randint(0, self.img_size - self.mask_size, 2)
|
|
y2, x2 = y1 + self.mask_size, x1 + self.mask_size
|
|
masked_part = img[:, y1:y2, x1:x2]
|
|
masked_img = img.copy()
|
|
masked_img[:, y1:y2, x1:x2] = 1
|
|
|
|
return masked_img, masked_part
|
|
|
|
def apply_center_mask(self, img):
|
|
"""Mask center part of image"""
|
|
# Get upper-left pixel coordinate
|
|
i = (self.img_size - self.mask_size) // 2
|
|
masked_img = img.copy()
|
|
masked_img[:, i : i + self.mask_size, i : i + self.mask_size] = 1
|
|
|
|
return masked_img, i
|
|
|
|
def __getitem__(self, index):
|
|
img = Image.open(self.files[index % len(self.files)])
|
|
img = self.transform(img)
|
|
if self.mode == "train":
|
|
# For training data perform random mask
|
|
masked_img, aux = self.apply_random_mask(img)
|
|
else:
|
|
# For test data mask the center of the image
|
|
masked_img, aux = self.apply_center_mask(img)
|
|
|
|
return img, masked_img, aux |