# dataset.py
import os
import logging

from PIL import Image
from torch.utils.data import Dataset

import transforms  # local transforms module: each transform acts on a list of images so image, trimap and matte stay aligned

logger = logging.getLogger('dataset')

class SHMDataset(Dataset):
    """Dataset yielding image / trimap / matte samples for SHM training."""

    def __init__(self, args, split='train'):
        super().__init__()
        self.image_dir = args.image_dir
        self.matte_dir = args.matte_dir
        self.trimap_dir = args.trimap_dir
        self.patch_size = args.patch_size
        self.mode = args.mode
        self.split = split
        self.files = []
        self.create_transforms()
        # Index every file in the image directory, then keep a fixed slice per split.
        for name in os.listdir(args.image_dir):
            self.files.append(name)
        if split == 'train':
            self.files = self.files[:90]
        if split == 'test':
            self.files = self.files[:10]
    def create_transforms(self):
        transforms_list = []
        if self.mode == 'pretrain_tnet':
            transforms_list.extend([
                transforms.RandomCrop(400),
                transforms.RandomRotation(180),
                transforms.RandomHorizontalFlip()
            ])
        if self.mode == 'pretrain_mnet':
            transforms_list.extend([
                transforms.RandomCrop(320),
            ])
        if self.mode == 'end_to_end':
            transforms_list.extend([
                transforms.RandomCrop(800),
            ])
        transforms_list.extend([
            transforms.Resize((self.patch_size, self.patch_size)),
            transforms.ToTensor()
        ])
        self.transforms = transforms.Compose(transforms_list)
    def __getitem__(self, index):
        file_name = self.files[index]
        image_path = os.path.join(self.image_dir, file_name)
        image = Image.open(image_path)
        instance = {'name': file_name}
        if self.split == 'train':
            # Trimaps and mattes are stored as .png files named after the .jpg image.
            trimap_path = os.path.join(self.trimap_dir, file_name).replace('.jpg', '.png')
            matte_path = os.path.join(self.matte_dir, file_name).replace('.jpg', '.png')
            trimap = Image.open(trimap_path).convert('L')
            matte = Image.open(matte_path).convert('L')
            # The joint transform keeps image, trimap and matte spatially aligned.
            [image, trimap, matte] = self.transforms([image, trimap, matte])
            instance['image'] = image
            instance['trimap'] = trimap
            instance['matte'] = matte
        else:
            [image] = self.transforms([image])
            instance['image'] = image
        return instance
    def __len__(self):
        return len(self.files)
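

# Minimal usage sketch (not part of the original file): shows how SHMDataset could be
# wrapped in a torch DataLoader. The directory paths, patch size, mode and batch size
# below are placeholder assumptions, not values taken from the project.
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    args = SimpleNamespace(
        image_dir='data/images',    # hypothetical directory of .jpg images
        matte_dir='data/mattes',    # hypothetical directory of .png alpha mattes
        trimap_dir='data/trimaps',  # hypothetical directory of .png trimaps
        patch_size=320,
        mode='pretrain_tnet',
    )
    dataset = SHMDataset(args, split='train')
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for batch in loader:
        # Each training batch is a dict with 'name', 'image', 'trimap' and 'matte' entries.
        print(batch['name'], batch['image'].shape)
        break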