forked from HAHA-DL/MLDG
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_reader.py
92 lines (66 loc) · 2.8 KB
/
data_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from pathlib import Path
import h5py
import numpy as np
from loguru import logger
from flags import Flags
from utils import shuffle_data, unfold_label
class BatchImageGenerator:
    """Serves mini-batches of images and labels read from an HDF5 file.

    Training data is shuffled once on load and re-shuffled whenever the
    batch cursor wraps past the end of the dataset.
    """

    def __init__(self, flags: "Flags", stage: str, file_path: Path, b_unfold_label: bool):
        """Validate the stage, store configuration, and load the dataset.

        Args:
            flags: experiment flags; only ``batch_size`` is read here.
            stage: one of ``"train"``, ``"val"``, ``"test"``.
            file_path: path to an HDF5 file with ``images``/``labels`` datasets.
            b_unfold_label: if True, one-hot encode the labels.

        Raises:
            ValueError: if ``stage`` is not a recognized value.
        """
        if stage not in ("train", "val", "test"):
            raise ValueError("invalid stage!")
        self.configuration(flags, stage, file_path)
        self.load_data(b_unfold_label)

    def configuration(self, flags: "Flags", stage: str, file_path: Path):
        """Record batch size, stage, and source path; reset the batch cursor."""
        self.batch_size = flags.batch_size
        # -1 so the pre-increment in get_images_labels_batch starts at index 0.
        self.current_index = -1
        self.file_path = file_path
        self.stage = stage
        # NOTE(review): this flag is never read or updated anywhere in this
        # file — kept only for backward compatibility with external readers.
        self.shuffled = False

    def normalize(self, inputs):
        """Scale images to [0, 1], convert NHWC -> NCHW, and standardize.

        Uses the per-channel mean/std expected by PyTorch's pretrained
        (ImageNet) models.

        Args:
            inputs: array-like of shape (N, H, W, 3) with values in [0, 255].

        Returns:
            Float ndarray of shape (N, 3, H, W).
        """
        # ImageNet per-channel statistics, shaped for NCHW broadcasting.
        mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
        std = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
        inputs = np.asarray(inputs) / 255.0            # norm to [0, 1]
        inputs = np.transpose(inputs, (0, 3, 1, 2))    # NHWC -> NCHW
        # Broadcasting replaces the original per-image/per-channel Python loops.
        return (inputs - mean) / std

    def load_data(self, b_unfold_label: bool):
        """Read images/labels from the HDF5 file and prepare them for batching.

        Raises:
            ValueError: if the image and label counts disagree.
        """
        # Explicit read-only mode: older h5py versions defaulted to append
        # mode, which would fail (or worse, create a file) for a pure reader.
        with h5py.File(self.file_path, "r") as f:
            self.images = np.array(f["images"])
            self.labels = np.array(f["labels"])
        # Shift the labels so the smallest class id becomes 0.
        self.labels -= np.min(self.labels)
        if b_unfold_label:
            self.labels = unfold_label(labels=self.labels, classes=len(np.unique(self.labels)))
        if len(self.images) != len(self.labels):
            raise ValueError("the number of images must be equal to the number of labels")
        self.num_labels = len(self.labels)
        logger.info(f"{self.num_labels=}")
        # Only the training stream is shuffled up front.
        if self.stage == "train":
            self.images, self.labels = shuffle_data(samples=self.images, labels=self.labels)

    def get_images_labels_batch(self) -> tuple[np.ndarray, np.ndarray]:
        """Return the next ``batch_size`` images and labels.

        Wraps around (and re-shuffles) once the dataset is exhausted, so a
        single batch may span two epochs.

        Returns:
            A ``(images, labels)`` pair of stacked ndarrays of length
            ``batch_size``.
        """
        images = []
        labels = []
        for _ in range(self.batch_size):
            self.current_index += 1
            # Avoid overflow: wrap to the start and re-shuffle for a new epoch.
            # NOTE(review): the re-shuffle also fires for val/test on wrap —
            # confirm evaluation never relies on a stable sample order.
            if self.current_index > self.num_labels - 1:
                self.current_index %= self.num_labels
                self.images, self.labels = shuffle_data(samples=self.images, labels=self.labels)
            images.append(self.images[self.current_index])
            labels.append(self.labels[self.current_index])
        return np.stack(images), np.stack(labels)