forked from z-fabian/flash-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_transforms.py
118 lines (100 loc) · 4.09 KB
/
data_transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import torch
import torchvision
from torchvision.transforms.functional import resize, center_crop
from scripts.utils import str2int, rescale_to_minusone_one
from data_utils.operators import create_operator, create_noise_schedule
class SevEncInputTransform:
    """Preprocess a single image into a batched tensor.

    The pipeline resizes the input to 256x256, converts it with
    ``ToTensor`` and normalizes three channels with mean 0.5 and std 0.5.
    Calling the instance applies that pipeline and prepends a dimension of
    size one to the result.
    """

    def __init__(self):
        transforms = torchvision.transforms
        steps = [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, image):
        """Apply the pipeline to ``image`` and add a leading dimension."""
        return self.transform(image).unsqueeze(0)
class ImageDataTransform:
    """Produce a clean image together with degraded / noisy versions of it.

    Args:
        is_train: When true, all randomness is unseeded. When false, a seed
            derived from the file name (``str2int(fname)``) drives the
            sampling of ``t`` and is forwarded to the forward operator and
            the noise scheduler, so results are repeatable per file.
        operator_schedule: Either a dict, which is handed to
            ``create_operator``, or a ready-made operator that is called as
            ``operator(image, t, seed=seed)``.
        noise_schedule: ``None`` for no added noise, a dict handed to
            ``create_noise_schedule``, or a ready-made scheduler called as
            ``scheduler(t, shape, seed=seed)`` and returning
            ``(noise, noise_std)``.
        fixed_t: If not ``None``, this value is used as ``t`` for every
            sample (``0`` is a valid fixed value).
        t_range: Optional ``(low, high)`` pair. When ``fixed_t`` is ``None``
            ``t`` is drawn uniformly between the two; without a range it is
            drawn with ``torch.rand(1)``.
        range_zero_one: If true, image tensors are returned as computed from
            the ``/ 255`` scaled input; otherwise each one is passed through
            ``rescale_to_minusone_one``.
    """

    def __init__(self,
                 is_train,
                 operator_schedule,
                 noise_schedule=None,
                 fixed_t=None,
                 t_range=None,
                 range_zero_one=False,
                 ):
        self.is_train = is_train
        self.range_zero_one = range_zero_one

        # A dict is treated as a config to build from; anything else is
        # assumed to already be the operator.
        if isinstance(operator_schedule, dict):
            self.fwd_operator = create_operator(operator_schedule)
        else:
            self.fwd_operator = operator_schedule

        # Same convention for the noise schedule, with None meaning "no noise".
        if noise_schedule is None:
            self.noise_scheduler = None
        elif isinstance(noise_schedule, dict):
            self.noise_scheduler = create_noise_schedule(noise_schedule)
        else:
            self.noise_scheduler = noise_schedule

        self.fixed_t = fixed_t
        self.t_range = t_range

    def _sample_t(self, seed):
        """Return the degradation level ``t`` for one sample.

        ``seed`` is only used when ``is_train`` is false and no ``fixed_t``
        is configured; it seeds a private generator so the draw is repeatable.
        """
        # Compare against None explicitly: a fixed level of 0 is legitimate
        # and must not fall through to random sampling.
        if self.fixed_t is not None:
            return torch.tensor(self.fixed_t)

        if self.is_train:
            u = torch.rand(1)
        else:
            g = torch.Generator()
            g.manual_seed(seed)
            u = torch.rand(1, generator=g)

        if self.t_range is not None:
            # Map the unit-interval draw onto [t_range[0], t_range[1]].
            return u * (self.t_range[1] - self.t_range[0]) + self.t_range[0]
        return u

    @torch.no_grad()
    def __call__(self,
                 image,
                 fname=None
                 ):
        """Transform ``image`` into a dict of clean and corrupted tensors.

        Args:
            image: Image object exposing ``.size`` and convertible with
                ``np.array`` to an (H, W, C) array — presumably a PIL image
                with a channel axis; TODO confirm against the dataset.
            fname: File name; required when ``is_train`` is false because
                the seed is derived from it.

        Returns:
            Dict with keys ``'clean'``, ``'degraded'``, ``'degraded_noisy'``,
            ``'noise_std'``, ``'noisy'``, ``'t'`` and ``'fname'``.
            ``'noise_std'`` is ``0.0`` when no noise scheduler is configured.

        Raises:
            ValueError: If ``is_train`` is false and ``fname`` is ``None``.
        """
        # Validate up front (and with a real exception rather than an
        # assert, which disappears under ``python -O``).
        if not self.is_train and fname is None:
            raise ValueError(
                'fname is required when is_train is False: '
                'the deterministic seed is derived from it.'
            )

        # Crop image to square
        shorter = min(image.size)
        image = center_crop(image, shorter)

        # Resize images to uniform size
        image = resize(image, (256, 256))

        # Convert to ndarray and permute dimensions to C, H, W
        image = np.array(image)
        image = image.transpose(2, 0, 1)

        # Normalize image to range [0, 1]
        image = image / 255.

        # Convert to tensor and add a leading dimension of size one
        image = torch.from_numpy(image.astype(np.float32))
        image = image.unsqueeze(0)

        if not self.is_train:  # deterministic forward model for validation
            seed = str2int(fname)
        else:
            seed = None

        # Generate degraded noisy images
        t = self._sample_t(seed)
        degraded = self.fwd_operator(image, t, seed=seed).squeeze(0)

        if self.noise_scheduler:
            z, noise_std = self.noise_scheduler(t, image.shape, seed=seed)
            degraded_noisy = degraded + z.to(image.device)
            noisy = image + z.to(image.device)
        else:
            degraded_noisy = degraded
            noisy = image
            noise_std = 0.0

        # Drop the leading dimension again before returning.
        image = image.squeeze(0)
        degraded_noisy = degraded_noisy.squeeze(0)
        noisy = noisy.squeeze(0)

        return {
            'clean': image if self.range_zero_one else rescale_to_minusone_one(image),
            'degraded': degraded if self.range_zero_one else rescale_to_minusone_one(degraded),
            'degraded_noisy': degraded_noisy if self.range_zero_one else rescale_to_minusone_one(degraded_noisy),
            'noise_std': noise_std,
            'noisy': noisy if self.range_zero_one else rescale_to_minusone_one(noisy),
            't': t,
            'fname': fname,
        }