-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtest_BRN.py
114 lines (84 loc) · 3.45 KB
/
test_BRN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import cv2
import os
import argparse
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from utils import *
from generator import BRN, print_network
import time
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``False``/``1``/``0``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string (including ``"False"``), so a flag declared that way
    can never be turned off. This converter accepts the usual spellings
    explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


parser = argparse.ArgumentParser(description="BRN_Test")
parser.add_argument("--logdir", type=str, default="logs/dataset/BRN", help='path of log files')
parser.add_argument("--data_path", type=str, default="dataset/...", help='path to testing data')
parser.add_argument("--save_path", type=str, default="results/dataset/BRN", help='path to save results')
# _str2bool instead of bool so that "--use_GPU False" actually yields False.
parser.add_argument("--use_GPU", type=_str2bool, default=True, help='use GPU or not')
parser.add_argument("--gpu_id", type=str, default="0", help='GPU id')
parser.add_argument("--inter_iter", type=int, default=8, help='number of inter_iteration')
opt = parser.parse_args()
# Set at import time, before main() moves anything to CUDA; the variable is
# only honoured if it is set before CUDA is initialised.
if opt.use_GPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu_id
def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of *state_dict* with a leading *prefix* removed from keys.

    Checkpoints saved from a model wrapped in ``nn.DataParallel`` carry a
    ``module.`` prefix on every key. Keys are only stripped when they actually
    start with *prefix*, so checkpoints saved without the wrapper load
    unchanged (blindly dropping the first 7 characters would corrupt them).

    Args:
        state_dict: mapping of parameter names to tensors.
        prefix: the key prefix to remove; defaults to ``'module.'``.

    Returns:
        An ``OrderedDict`` with the same values and insertion order.
    """
    from collections import OrderedDict
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def main():
    """Run BRN over every image in ``opt.data_path`` and save the results.

    Loads ``net_latest.pth`` from ``opt.logdir`` and feeds each file accepted by
    ``is_image`` through the model one at a time. Each clamped output is written
    under the same file name in ``opt.save_path``. Per-image and average forward
    times are printed.
    """
    os.makedirs(opt.save_path, exist_ok=True)

    # Build model
    print('Loading model ...\n')
    model = BRN(opt.inter_iter, opt.use_GPU)
    print_network(model)
    if opt.use_GPU:
        model = model.cuda()
    checkpoint_path = os.path.join(opt.logdir, 'net_latest.pth')
    # On CPU-only runs, map tensors to the CPU so that a checkpoint saved from
    # CUDA tensors can still be deserialised; load_state_dict then copies the
    # values into the model's own parameters.
    state_dict = torch.load(checkpoint_path, map_location=None if opt.use_GPU else 'cpu')
    model.load_state_dict(strip_module_prefix(state_dict))
    model.eval()

    # load data info
    print('Loading data info ...\n')
    time_test = 0
    count = 0
    for img_name in os.listdir(opt.data_path):
        if not is_image(img_name):
            continue
        img_path = os.path.join(opt.data_path, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            print('Skipping unreadable image:', img_path)
            continue
        # OpenCV loads channels as BGR; reorder to RGB before feeding the model.
        b, g, r = cv2.split(img)
        img = cv2.merge([r, g, b])
        img = normalize(np.float32(img))  # see utils.normalize for the scaling used
        # HxWxC -> 1xCxHxW
        img = np.expand_dims(img.transpose(2, 0, 1), 0)
        noisy = torch.Tensor(img)
        if opt.use_GPU:
            noisy = noisy.cuda()

        with torch.no_grad():  # inference only; avoids building the autograd graph
            # CUDA work is asynchronous: synchronize around the forward pass so the
            # measured interval covers the kernels themselves. Skipped on CPU, where
            # torch.cuda.synchronize() would fail.
            if opt.use_GPU:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            out, _, _, _ = model(noisy)
            out = torch.clamp(out, 0., 1.)
            if opt.use_GPU:
                torch.cuda.synchronize()
            dur_time = time.perf_counter() - start_time
        print(img_name)
        print(dur_time)
        time_test += dur_time

        # Take batch element 0 (rather than squeeze) so 1-pixel-high/wide images
        # keep their spatial dims; values are in [0, 1], so 255 * out fits uint8.
        save_out = np.uint8(255 * out[0].detach().cpu().numpy())
        save_out = save_out.transpose(1, 2, 0)  # CxHxW -> HxWxC
        # Back to BGR channel order for cv2.imwrite.
        r, g, b = cv2.split(save_out)
        save_out = cv2.merge([b, g, r])
        cv2.imwrite(os.path.join(opt.save_path, img_name), save_out)
        count += 1

    if count:
        print('Avg. time:', time_test / count)
    else:
        print('No images processed in', opt.data_path)


if __name__ == "__main__":
    main()