Skip to content

Commit 87d6613

Browse files
committed
first commit
0 parents  commit 87d6613

30 files changed

+64267
-0
lines changed

Datahelper2.py

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import pandas as pd
2+
from torch import np # Torch wrapper for Numpy
3+
4+
import os
5+
from PIL import Image
6+
import h5py
7+
import random
8+
import torch
9+
from torch.utils.data.dataset import Dataset
10+
from torch.utils.data import DataLoader
11+
from torchvision import transforms
12+
import pickle
13+
14+
class AmazonDateset_train(Dataset):
    """Training dataset that exposes each image under eight fixed orientations.

    Item ``index`` maps to image ``index // 8`` of ``train_index`` and to
    orientation ``index % 8``: ids 4-7 mirror the image left-right first, and
    ``id % 4`` selects no rotation or PIL's ROTATE_90 / ROTATE_180 /
    ROTATE_270 transpose.  The dataset length is therefore
    ``8 * len(train_index)``.

    Args:
        train_index: Indexable collection of image ids; each id ``i`` refers
            to the file ``img_path + 'train_' + str(i) + img_ext`` and to the
            label stored under the key ``'train_' + str(i)``.
        img_path: Prefix prepended to the file name (concatenated as-is, so a
            directory needs its trailing separator).
        img_ext: File extension including the dot, e.g. ``'.jpg'``.
        label_path: Path of a pickle file holding a mapping from image name
            to a NumPy array (it is passed to ``torch.from_numpy``).
        resize: Target size handed to torchvision's resize transform.
            ``None`` (the default) and ``256`` both skip resizing.
    """

    def __init__(self, train_index, img_path, img_ext, label_path, resize=None):
        super(AmazonDateset_train, self).__init__()
        self.img_path = img_path
        self.img_ext = img_ext

        steps = []
        # The original code built Scale(None) for the default argument, which
        # torchvision rejects; treat None like 256 and leave the size alone.
        # NOTE(review): 256 is presumably the native image size -- confirm.
        if resize is not None and resize != 256:
            # `Scale` is the legacy name of `Resize`; use whichever exists.
            resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale
            steps.append(resize_cls(resize))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize([0.485, 0.456, 0.406],
                                          [0.229, 0.224, 0.225]))
        self.transform = transforms.Compose(steps)

        self.img_index = train_index
        # NOTE: pickle can execute arbitrary code while loading; only point
        # label_path at files you trust.
        with open(label_path, 'rb') as label_file:
            self.label = pickle.load(label_file)

    @staticmethod
    def _decode_transform(tft):
        """Split an orientation id into ``(mirror, rotation_slot)``."""
        return tft >= 4, tft % 4

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for the augmented item.

        Raises whatever ``self.img_index`` raises for an out-of-range image
        position, and ``KeyError`` if the label mapping lacks the image name.
        """
        img_pos, tft = divmod(index, 8)
        name = 'train_' + str(self.img_index[img_pos])
        mirror, slot = self._decode_transform(tft)
        rotation = (None, Image.ROTATE_90, Image.ROTATE_180,
                    Image.ROTATE_270)[slot]

        # The context manager closes the underlying file; convert() returns
        # an independent image, so `img` stays usable after the block.
        with Image.open(self.img_path + name + self.img_ext) as source:
            img = source
            if mirror:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            if rotation is not None:
                img = img.transpose(rotation)
            img = img.convert('RGB')

        img = self.transform(img)
        label = torch.from_numpy(self.label[name]).float()
        return img, label

    def __len__(self):
        """Return the number of items: eight orientations per image id."""
        return len(self.img_index) * 8
47+
48+
49+
class AmazonDateset_validate(Dataset):
    """Validation dataset that serves each image under one orientation.

    The orientation id is either the fixed ``transform_type`` or, when
    ``random_transform`` is true, a fresh ``random.randint(0, 7)`` per item.
    Ids 4-7 mirror the image left-right first, and ``id % 4`` selects no
    rotation or PIL's ROTATE_90 / ROTATE_180 / ROTATE_270 transpose.

    Args:
        validate_index: Indexable collection of image ids; each id ``i``
            refers to the file ``img_path + 'train_' + str(i) + img_ext`` and
            to the label stored under the key ``'train_' + str(i)``.
        img_path: Prefix prepended to the file name (concatenated as-is, so a
            directory needs its trailing separator).
        img_ext: File extension including the dot, e.g. ``'.jpg'``.
        label_path: Path of a pickle file holding a mapping from image name
            to a NumPy array (it is passed to ``torch.from_numpy``).
        transform_type: Orientation id used when ``random_transform`` is
            false.
        random_transform: Draw a random orientation id for every item.
        resize: Target size handed to torchvision's resize transform.
            ``None`` (the default) and ``256`` both skip resizing.
    """

    def __init__(self, validate_index, img_path, img_ext, label_path,
                 transform_type=0, random_transform=False, resize=None):
        super(AmazonDateset_validate, self).__init__()
        self.img_path = img_path
        self.img_ext = img_ext
        self.transform_type = transform_type
        self.random_transform = random_transform

        steps = []
        # The original code built Scale(None) for the default argument, which
        # torchvision rejects; treat None like 256 and leave the size alone.
        # NOTE(review): 256 is presumably the native image size -- confirm.
        if resize is not None and resize != 256:
            # `Scale` is the legacy name of `Resize`; use whichever exists.
            resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale
            steps.append(resize_cls(resize))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize([0.485, 0.456, 0.406],
                                          [0.229, 0.224, 0.225]))
        self.transform = transforms.Compose(steps)

        self.img_index = validate_index
        # NOTE: pickle can execute arbitrary code while loading; only point
        # label_path at files you trust.
        with open(label_path, 'rb') as label_file:
            self.label = pickle.load(label_file)

    @staticmethod
    def _decode_transform(tft):
        """Split an orientation id into ``(mirror, rotation_slot)``."""
        return tft >= 4, tft % 4

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for image ``index``.

        Raises whatever ``self.img_index`` raises for an out-of-range
        position, and ``KeyError`` if the label mapping lacks the image name.
        """
        name = 'train_' + str(self.img_index[index])

        # The context manager closes the underlying file; convert() returns
        # an independent image, so `img` stays usable after the block.
        with Image.open(self.img_path + name + self.img_ext) as source:
            if self.random_transform:
                tft = random.randint(0, 7)
            else:
                tft = self.transform_type
            mirror, slot = self._decode_transform(tft)
            rotation = (None, Image.ROTATE_90, Image.ROTATE_180,
                        Image.ROTATE_270)[slot]

            img = source
            if mirror:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            if rotation is not None:
                img = img.transpose(rotation)
            img = img.convert('RGB')

        img = self.transform(img)
        label = torch.from_numpy(self.label[name]).float()
        return img, label

    def __len__(self):
        """Return the number of image ids in the validation split."""
        return len(self.img_index)
84+
85+
class KaggleAmazonDataset_test(Dataset):
    """Unlabelled dataset over every file found in a directory.

    Each item is served under the fixed orientation ``transform_type``: ids
    4-7 mirror the image left-right first, and ``id % 4`` selects no rotation
    or PIL's ROTATE_90 / ROTATE_180 / ROTATE_270 transpose.

    Args:
        img_path: Directory whose entries are all treated as images.
        transform_type: Orientation id applied to every image.
        resize: Target size handed to torchvision's resize transform.
            ``None`` (the default) and ``256`` both skip resizing.
    """

    def __init__(self, img_path, transform_type=0, resize=None):
        # Initialise the Dataset base class, as the sibling datasets do.
        super(KaggleAmazonDataset_test, self).__init__()

        self.img_dir = img_path
        # os.listdir returns entries in arbitrary order; items are paired
        # with their file name in __getitem__, so order is not relied on here.
        self.img_list = os.listdir(img_path)
        self.transform_type = transform_type

        steps = []
        # The original code built Scale(None) for the default argument, which
        # torchvision rejects; treat None like 256 and leave the size alone.
        # NOTE(review): 256 is presumably the native image size -- confirm.
        if resize is not None and resize != 256:
            # `Scale` is the legacy name of `Resize`; use whichever exists.
            resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale
            steps.append(resize_cls(resize))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize([0.485, 0.456, 0.406],
                                          [0.229, 0.224, 0.225]))
        self.transform = transforms.Compose(steps)

    @staticmethod
    def _decode_transform(tft):
        """Split an orientation id into ``(mirror, rotation_slot)``."""
        return tft >= 4, tft % 4

    def __getitem__(self, index):
        """Return ``(image_tensor, name)`` for directory entry ``index``.

        ``name`` is the file name truncated at its first dot.
        """
        file_name = self.img_list[index]
        mirror, slot = self._decode_transform(self.transform_type)
        rotation = (None, Image.ROTATE_90, Image.ROTATE_180,
                    Image.ROTATE_270)[slot]

        # os.path.join yields the same path as plain concatenation when the
        # directory ends with a separator, and a correct one when it does not.
        # The context manager closes the underlying file; convert() returns
        # an independent image, so `img` stays usable after the block.
        with Image.open(os.path.join(self.img_dir, file_name)) as source:
            img = source
            if mirror:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            if rotation is not None:
                img = img.transpose(rotation)
            img = img.convert('RGB')

        img = self.transform(img)
        return img, file_name.split('.')[0]

    def __len__(self):
        """Return the number of directory entries captured at construction."""
        return len(self.img_list)
115+
116+
117+
118+
119+
if __name__ == '__main__':
    # Machine-specific locations used when this module is run as a script.
    # Nothing below is consumed within this block; the names are only bound.

    # Training images and their file extension.
    # Alternate location seen on another machine:
    #   '/home/jianglibin/PythonProject/Amazon/data/train-jpg/'
    IMG_PATH = '/home/kyle/PythonProject/AmazonData/train-jpg/'
    IMG_EXT = '.jpg'

    # Test images.
    IMG_TEST_PATH = '/home/kyle/PythonProject/AmazonData/test-jpg/'

    # Label sources and the train/validate split file.
    CSV_PATH = '/home/kyle/PythonProject/AmazonData/train_v2.csv'
    LABEL_PATH = '/home/kyle/PythonProject/Amazon/labels.h5'
    DS = '/home/kyle/PythonProject/Amazon/train_validate_dataset.h5'
128+
129+

0 commit comments

Comments
 (0)