15
15
class DiffCompletion (LightningModule ):
16
16
def __init__ (self , diff_path , refine_path , denoising_steps , cond_weight ):
17
17
super ().__init__ ()
18
- hparams = yaml . safe_load ( open ( diff_path . split ( 'checkpoints' )[ 0 ] + '/hparams.yaml' ) )
19
- self .save_hyperparameters (hparams )
18
+ ckpt_diff = torch . load ( diff_path )
19
+ self .save_hyperparameters (ckpt_diff [ 'hyper_parameters' ] )
20
20
assert denoising_steps <= self .hparams ['diff' ]['t_steps' ], \
21
21
f"The number of denoising steps cannot be bigger than T={ self .hparams ['diff' ]['t_steps' ]} (you've set '-T { denoising_steps } ')"
22
22
23
- ckpt_diff = torch .load (diff_path )
24
23
self .partial_enc = minknet .MinkGlobalEnc (in_channels = 3 , out_channels = self .hparams ['model' ]['out_dim' ]).cuda ()
25
24
self .model = minknet .MinkUNetDiff (in_channels = 3 , out_channels = self .hparams ['model' ]['out_dim' ]).cuda ()
26
25
self .model_refine = minknet .MinkUNet (in_channels = 3 , out_channels = 3 * 6 )
@@ -169,9 +168,17 @@ def completion_loop(self, x_init, x_t, x_cond, x_uncond):
169
168
170
169
return x_t .F .cpu ().detach ().numpy ()
171
170
171
+ def load_pcd (pcd_file ):
172
+ if pcd_file .endswith ('.bin' ):
173
+ return np .fromfile (pcd_file , dtype = np .float32 ).reshape ((- 1 ,4 ))[:,:3 ]
174
+ elif pcd_file .endswith ('.ply' ):
175
+ return np .array (o3d .io .read_point_cloud (pcd_file ).points )
176
+ else :
177
+ print (f"Point cloud format '.{ pcd_file .split ('.' )[- 1 ]} ' not supported. (supported formats: .bin (kitti format), .ply)" )
178
+
172
179
@click .command ()
173
- @click .option ('--diff' , '-d' , type = str , default = '' , help = 'path to the scan sequence' )
174
- @click .option ('--refine' , '-r' , type = str , default = '' , help = 'path to the scan sequence' )
180
+ @click .option ('--diff' , '-d' , type = str , default = 'checkpoints/diff_net.ckpt ' , help = 'path to the scan sequence' )
181
+ @click .option ('--refine' , '-r' , type = str , default = 'checkpoints/refine_net.ckpt ' , help = 'path to the scan sequence' )
175
182
@click .option ('--denoising_steps' , '-T' , type = int , default = 50 , help = 'number of denoising steps (default: 50)' )
176
183
@click .option ('--cond_weight' , '-s' , type = float , default = 6.0 , help = 'conditioning weight (default: 6.0)' )
177
184
def main (diff , refine , denoising_steps , cond_weight ):
@@ -188,8 +195,7 @@ def main(diff, refine, denoising_steps, cond_weight):
188
195
189
196
for pcd_path in tqdm .tqdm (natsorted (os .listdir (path ))):
190
197
pcd_file = os .path .join (path , pcd_path )
191
- input_pcd = o3d .io .read_point_cloud (pcd_file )
192
- points = np .array (input_pcd .points )
198
+ points = load_pcd (pcd_file )
193
199
194
200
start = time .time ()
195
201
refine_scan , diff_scan = diff_completion .complete_scan (points )
0 commit comments