-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
284 lines (238 loc) · 11.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from importlib import reload
import os
import utils.color as color
import utils.graph as graph
import utils.gft as gft
import utils.morton as morton
import utils.clustering as clustering
import utils.quantization as quantization
import utils.bitcoding as bitcoding
import utils.ply as ply
# Offset of each supported 8iVFBv2 sequence; `frame_number` is added to it to
# obtain the frame index that appears in the PLY filename.
_SEQUENCE_START_FRAMES = {
    'loot': 999,
    'longdress': 1050,
    'soldier': 535,
    'redandblack': 1449,
}

# Refinement methods that are forwarded to clustering.refine_centers().
_REFINEMENT_METHODS = ('VA', 'weight', 'weight1')


def run(method='kmeans', sequence='loot', frame_number=1, qsteps=(16,), bsize=16, clusters_count=1500, lambd=0.2, colorspace='lab', qstep_centers=10, ref_iterations=1, ref_method='none', beta=2.0, bitstream_directory='/tmp', export_ply_blocks=False, decoder_match_test=False, dataset_root='/path/to/8iVFBv2'):
    """Encode one point-cloud frame and collect rate/quality/timing measurements.

    The frame is loaded, partitioned into blocks (octree blocks or k-means
    clusters), transformed block-wise with the GFT, and then quantized, sorted
    and bit-coded once per entry of ``qsteps``.

    Args:
        method: Partitioning method, ``'octree'`` or ``'kmeans'``.
        sequence: Sequence name; one of the keys of ``_SEQUENCE_START_FRAMES``.
        frame_number: Frame offset inside the sequence.
        qsteps: Iterable of quantization step sizes for the GFT coefficients.
        bsize: Block size passed to ``graph.block_indices`` (octree only).
        clusters_count: Requested number of clusters (kmeans only).
        lambd: Lambda parameter forwarded to the clustering routines.
        colorspace: Colorspace forwarded to ``clustering.kmeans_encoder``.
        qstep_centers: Quantization step size for the cluster centers.
        ref_iterations: Number of centers-refinement iterations.
        ref_method: ``'none'``, ``'VA'``, ``'weight'`` or ``'weight1'``.
        beta: Beta parameter forwarded to ``clustering.refine_centers``.
        bitstream_directory: Directory the bitstream files are written to.
        export_ply_blocks: Accepted for compatibility; not used by this function.
        decoder_match_test: If true, run ``decode`` for every qstep and print
            the summed absolute difference to the encoder-side reconstruction.
        dataset_root: Directory that contains the ``<sequence>/Ply`` folders.

    Returns:
        A float array with one row per qstep and 11 columns:
        ``qstep, YPSNR_coeff, bs_total, bs_coeffs, bs_centers, bs_dupes,
        dupes_count, N, t_partitioning, t_gft, t_centers_ref``.
        An empty 1-D array is returned when ``lambd == 0.0`` and
        ``ref_iterations > 0`` (such jobs are skipped on purpose).

    Raises:
        ValueError: If ``method``, ``sequence`` or a requested ``ref_method``
            is not one of the supported values.
    """
    # initialize results array
    results = np.empty((0, 11))  # 11 is the number of measurements we make

    # check that we don't do refinement for lambda values of 0.0 (makes it
    # easier to start jobs this way because of the zipping of arguments)
    if lambd == 0.0 and ref_iterations > 0:
        print("We don't continue if lambda is 0.0 and ref_iterations is bigger than 0")
        return np.array([])

    # Validate up front so a typo fails with a clear message instead of an
    # UnboundLocalError somewhere deep in the pipeline.
    if method not in ('octree', 'kmeans'):
        raise ValueError(f"Unknown method {method!r}; expected 'octree' or 'kmeans'")
    if sequence not in _SEQUENCE_START_FRAMES:
        raise ValueError(
            f"Unknown sequence {sequence!r}; expected one of {sorted(_SEQUENCE_START_FRAMES)}"
        )
    refinement_requested = not (ref_method == 'none' or ref_iterations < 1)
    if method == 'kmeans' and refinement_requested and ref_method not in _REFINEMENT_METHODS:
        raise ValueError(
            f"Unknown ref_method {ref_method!r}; expected 'none' or one of {_REFINEMENT_METHODS}"
        )

    # process parameters
    frame = _SEQUENCE_START_FRAMES[sequence] + frame_number
    filename = f'{dataset_root}/{sequence}/Ply/{sequence}_vox10_{frame:04d}.ply'

    print(f'Processing sequence {sequence}, frame number {frame_number} of 300\n'
          f'using {bitstream_directory} as temporary bitstream directory\n'
          f'using parameters:\n'
          f'quantize step sizes: {*qsteps,}\n'
          f'method: {method}')
    if method == 'octree':
        print(f'block size: {bsize}')
    elif method == 'kmeans':
        print(f'clusters count: {clusters_count}\n'
              f'lambda parameter: {lambd}\n'
              f'colorspace: {colorspace}\n'
              f'centers quantization step: {qstep_centers}\n'
              f'refinement method: {ref_method}\n'
              f'beta parameter: {beta}\n'
              f'refinement iterations: {ref_iterations}\n')
    if decoder_match_test:
        print('Encoder/decoder match will be checked!\n')

    ## Encoder
    ### Load/prepare data
    print('Loading pointcloud data...', end='', flush=True)
    # load vertices and attributes
    V, A_rgb, N, bitresolution = ply.read(filename)
    print(f'loaded {N} points.')

    # color conversion from RGB to YUV
    A = color.rgb_to_yuv(A_rgb)

    ### Partitioning
    t = {}
    print('Partitioning pointcloud...', end='', flush=True)
    t['partitioning'] = time.time()
    # Only the kmeans path produces centers/labels; keep the names bound for
    # the octree path so the later stages can test for None.
    centers_decoder = None
    labels_decoder = None
    if method == 'octree':
        idx_start, idx_stop, N_block = graph.block_indices(V=V, bsize=bsize)
        V_ordered = V
        A_ordered = A
    else:  # kmeans
        (_centers_encoder, _labels_encoder, centers_decoder, labels_decoder,
         V_ordered, A_ordered, idx_start, idx_stop, N_block) = clustering.kmeans_encoder(
            V=V, A=A, A_rgb=A_rgb, partition_count=clusters_count, lam_part=lambd,
            colorspace=colorspace, qstep_centers=qstep_centers, emulate_dec=True
        )
    t['partitioning'] = time.time() - t['partitioning']
    print('done')

    ### Centers refinement
    t['centers_ref'] = time.time()
    if method != 'kmeans' or not refinement_requested:
        # Octree partitioning has no centers, so there is nothing to refine.
        print('No centers refinement will be performed.')
        centers_decoder_refined = centers_decoder
    else:
        print('Refining centers...', end='', flush=True)
        # make centers 6-dimensional for the VA method
        if ref_method == 'VA':
            centers_ref_VA_init = np.zeros((centers_decoder.shape[0], 6))
            centers_ref_VA_init[:, 0:3] = centers_decoder
            A_lab = color.rgb_to_lab(rgb=A_rgb)
            for cluster_id in range(centers_decoder.shape[0]):
                centers_ref_VA_init[cluster_id, 3:] = lambd * np.mean(
                    A_lab[labels_decoder == cluster_id, :], axis=0
                )
            centers_decoder = centers_ref_VA_init
        # ref_method was validated above, so it is one of _REFINEMENT_METHODS.
        (centers_decoder_refined, _labels_refined, V_ordered, A_ordered,
         idx_start, idx_stop, N_block, _dist_refined) = clustering.refine_centers(
            V=V, A_yuv=A, A_rgb=A_rgb, centers_init=centers_decoder,
            labels_init=labels_decoder, N_iter=ref_iterations, method=ref_method,
            beta=beta, lam_part=lambd, scaling=False
        )
        print('done')
    t['centers_ref'] = time.time() - t['centers_ref']

    ### GFT per block
    print('Calculate GFT...')
    t['gft'] = time.time()
    Q = np.ones((N, 1))
    Ahat, _res, GFT_blocks, Gfreq_blocks = gft.transform_block_gft(
        V=V_ordered, A=A_ordered, Q=Q, idx_start=idx_start, idx_stop=idx_stop)
    t['gft'] = time.time() - t['gft']

    for qstep in qsteps:
        print(f'Processing for qstep={qstep}...')

        ### Quantize
        print('Quantizing GFT coefficients...', end='', flush=True)
        Ahat_quant, Ahat_quant_idx = quantization.quantize(x=Ahat, qstep=qstep)
        print('done')
        YPSNR_coeff = color.YPSNR(Ahat, Ahat_quant, N)
        print(f'YPSNR|coeff={YPSNR_coeff:2.4f} dB')

        ### Sort
        sort_method = 'dc_subgraphs'  # 'none' 'dc'
        print('Sorting coefficients...', end='', flush=True)
        Ahat_quant_idx_sorted, _mask_lo, _mask_hi, _num_subgraphs_blocks = gft.sort_block_gft_coeffs(
            Ahat=Ahat_quant_idx, Gfreq_blocks=Gfreq_blocks,
            idx_start=idx_start, idx_stop=idx_stop, N_block=N_block,
            sort_method=sort_method)
        print('done')

        ### Bit coding
        print('Bit coding coefficients (and centers)...', end='', flush=True)
        # encode the number of unused/duplicated centers if we used kmeans;
        # both values stay 0 for octree so the result row is always complete
        bs_dupes = 0
        dupes_count = 0
        if method == 'kmeans':
            dupes_count = clusters_count - centers_decoder_refined.shape[0]
            bs_dupes = bitcoding.write_number_to_file(
                x=dupes_count, filename='dupes_count.bin', bitstream_directory=bitstream_directory
            )

        # encode quantized coefficients indices
        bs_coeffs = bitcoding.code_YUV(Ahat_quant_idx_sorted, N=N, bitstream_directory=bitstream_directory)

        # encode centers if using kmeans and if lambda is different from 0
        bs_centers = 0
        if method == 'kmeans' and lambd != 0.0:
            # differential coding of the centers indices
            # (centers_decoder_refined is unsigned, np.diff() yields signed values)
            _, centers_decoder_refined_idx_int = quantization.quantize(centers_decoder_refined, qstep_centers)
            centers_decoder_refined_idx_diff = np.vstack((
                centers_decoder_refined_idx_int[0, :],  # save first entry
                np.diff(centers_decoder_refined_idx_int, axis=0)  # and then all the differences
            ))
            bs_centers = bitcoding.encode_rlgr(
                data=centers_decoder_refined_idx_diff.flatten('F').tolist(),
                filename=os.path.join(bitstream_directory, 'bitstream_centers.bin'),
                is_signed=1  # differences have a sign
            )

        bs_total = bs_coeffs + bs_centers + bs_dupes
        print('done')
        print('Sorted: Coded Y,U,V separately: rate={:2.4f} bits/symbol'.format(bs_total / N))

        ## Decoder
        if decoder_match_test:
            # NOTE(review): for kmeans with lambd == 0.0 no centers bitstream is
            # written above, yet decode() always reads one -- confirm intent.
            print('Encoder/decoder check...')
            ### Inverse GFT (necessary only for enc/dec match test)
            print('Perform inverse GFT (for decoder check)...')
            A_quant = gft.itransform_block_gft(
                V=V_ordered, Ahat=Ahat_quant, Q=Q,
                idx_start=idx_start, idx_stop=idx_stop, GFT_blocks=GFT_blocks)
            _, A_quant_dec = decode(
                V=V, N=N, method=method, qstep=qstep, bsize=bsize, clusters_count=clusters_count,
                lambd=lambd, colorspace=colorspace, qstep_centers=qstep_centers,
                bitstream_directory=bitstream_directory
            )
            ### Encoder-decoder match check
            print(f'Encoder-decoder match: {np.sum(np.abs(A_quant - A_quant_dec))}')

        # append the measurements for the current qstep to the results array
        current_result = np.array([
            qstep, YPSNR_coeff, bs_total, bs_coeffs, bs_centers, bs_dupes,
            dupes_count, N, t['partitioning'], t['gft'], t['centers_ref'],
        ])
        results = np.concatenate((results, current_result[np.newaxis, :]))

    return results
def decode(
    V, N, method='kmeans', qstep=16, bsize=16, clusters_count=1500, lambd=0.3, colorspace='lab', qstep_centers=10, bitstream_directory='/tmp'
):
    """Decode the bitstreams in ``bitstream_directory`` back to attributes.

    Mirrors the encoder: bit-decodes the coefficient indices (and, for kmeans,
    the differentially coded cluster centers), rebuilds the partitioning and
    the per-block GFT from the geometry ``V``, undoes the coefficient sorting,
    dequantizes and applies the inverse GFT.

    Args:
        V: Point-cloud vertices (geometry), as used by the encoder.
        N: Number of points.
        method: Partitioning method, ``'octree'`` or ``'kmeans'``.
        qstep: Quantization step size of the GFT coefficients.
        bsize: Block size passed to ``graph.block_indices`` (octree only).
        clusters_count: Cluster count requested at the encoder; the number read
            from ``dupes_count.bin`` is subtracted from it (kmeans only).
        lambd: Accepted for signature parity with the encoder; not used here.
        colorspace: Accepted for signature parity with the encoder; not used here.
        qstep_centers: Quantization step size of the cluster centers.
        bitstream_directory: Directory the bitstream files are read from.

    Returns:
        Tuple ``(V_ordered_dec, A_quant_dec)``: the vertices in block order and
        the attributes reconstructed by the inverse GFT.

    Raises:
        ValueError: If ``method`` is neither ``'octree'`` nor ``'kmeans'``.
    """
    if method not in ('octree', 'kmeans'):
        raise ValueError(f"Unknown method {method!r}; expected 'octree' or 'kmeans'")
    use_kmeans = method == 'kmeans'

    ### Bit decoding
    print('Decoding coefficients and centers...', end='', flush=True)
    if use_kmeans:
        clusters_count_dec = clusters_count - bitcoding.get_number_from_file(
            filename='dupes_count.bin', bitstream_directory=bitstream_directory
        )

    # decode quantization indices
    Ahat_quant_idx_sorted_dec = bitcoding.decode_YUV(N, bitstream_directory)

    # decode cluster_centers
    if use_kmeans:
        # Decode and reshape: the encoder flattened the (count, 3) array in
        # column-major order, so reshape to (3, count) and transpose.
        centers_decoder_refined_idx_diff_dec = bitcoding.decode_rlgr(
            filename=os.path.join(bitstream_directory, 'bitstream_centers.bin'),
            N=clusters_count_dec * 3, is_signed=1
        )
        centers_decoder_refined_idx_diff_dec = centers_decoder_refined_idx_diff_dec.reshape((3, -1)).T
        # Invert np.diff(): the first row is stored verbatim, so a cumulative sum restores the indices
        centers_decoder_refined_idx_int_dec = np.cumsum(centers_decoder_refined_idx_diff_dec, axis=0)
        # dequantize
        centers_decoder_refined_int_dec = quantization.dequantize(centers_decoder_refined_idx_int_dec, qstep_centers)
        clusters_dec = centers_decoder_refined_int_dec.astype(np.uint32)
    print('done')

    ### Partitioning
    print('Partitioning pointcloud...', end='', flush=True)
    if use_kmeans:
        labels_dec = clustering.labels_from_centers(
            V=V, centers=clusters_dec
        )
        idx_start_dec, idx_stop_dec, N_block_dec, V_ordered_dec, _ = clustering.block_indices(
            V=V, A=np.zeros_like(V), labels=labels_dec, clusters_count=clusters_count_dec
        )
    else:  # octree
        idx_start_dec, idx_stop_dec, N_block_dec = graph.block_indices(V=V, bsize=bsize)
        # Octree blocks keep the input order (the encoder uses V_ordered = V too).
        V_ordered_dec = V
    print('done')

    ### Create GFT block Matrices and frequencies
    print('Calculate GFT...')
    Q_dec = np.ones((N, 1))
    # Only the transform matrices and graph frequencies are needed here, so
    # all-zero attributes are passed in.
    _, _, GFT_blocks_dec, Gfreq_blocks_dec = gft.transform_block_gft(
        V=V_ordered_dec, A=np.zeros_like(V), Q=Q_dec,
        idx_start=idx_start_dec, idx_stop=idx_stop_dec, ret_GFT=True
    )

    ### Inverse sorting
    print('Inverse sort the coefficients...', end='', flush=True)
    sort_method = 'dc_subgraphs'  # 'none' 'dc'
    mask_lo_dec, mask_hi_dec, _ = gft.create_sort_masks_subgraphs(
        Gfreq_blocks=Gfreq_blocks_dec,
        idx_start=idx_start_dec,
        idx_stop=idx_stop_dec,
        N_block=N_block_dec,
        sort_method=sort_method
    )
    Ahat_quant_idx_dec = gft.reverse_sort_block_gft_coeffs(
        Ahat_sort=Ahat_quant_idx_sorted_dec,
        mask_lo=mask_lo_dec,
        mask_hi=mask_hi_dec
    )
    print('done')

    ### Inverse quantization
    print('Dequantize the coefficients...', end='', flush=True)
    Ahat_quant_dec = quantization.dequantize(Ahat_quant_idx_dec, qstep)
    print('done')

    ### Inverse GFT
    print('Inverse GFT...')
    A_quant_dec = gft.itransform_block_gft(
        V=V_ordered_dec, Ahat=Ahat_quant_dec, Q=Q_dec,
        idx_start=idx_start_dec, idx_stop=idx_stop_dec, GFT_blocks=GFT_blocks_dec)

    return V_ordered_dec, A_quant_dec


# default call when script is called on its own
# (kept below decode() so that every function run() may call is already defined)
if __name__ == "__main__":
    run(method='kmeans')