Skip to content

Commit ebdc87e

Browse files
committed
minors and a refined diff_pipeline tool
1 parent b114fec commit ebdc87e

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

lidiff/tools/diff_completion_pipeline.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
class DiffCompletion(LightningModule):
1616
def __init__(self, diff_path, refine_path, denoising_steps, cond_weight):
1717
super().__init__()
18-
hparams = yaml.safe_load(open(diff_path.split('checkpoints')[0] + '/hparams.yaml'))
19-
self.save_hyperparameters(hparams)
18+
ckpt_diff = torch.load(diff_path)
19+
self.save_hyperparameters(ckpt_diff['hyper_parameters'])
2020
assert denoising_steps <= self.hparams['diff']['t_steps'], \
2121
f"The number of denoising steps cannot be bigger than T={self.hparams['diff']['t_steps']} (you've set '-T {denoising_steps}')"
2222

23-
ckpt_diff = torch.load(diff_path)
2423
self.partial_enc = minknet.MinkGlobalEnc(in_channels=3, out_channels=self.hparams['model']['out_dim']).cuda()
2524
self.model = minknet.MinkUNetDiff(in_channels=3, out_channels=self.hparams['model']['out_dim']).cuda()
2625
self.model_refine = minknet.MinkUNet(in_channels=3, out_channels=3*6)
@@ -169,9 +168,17 @@ def completion_loop(self, x_init, x_t, x_cond, x_uncond):
169168

170169
return x_t.F.cpu().detach().numpy()
171170

171+
def load_pcd(pcd_file):
172+
if pcd_file.endswith('.bin'):
173+
return np.fromfile(pcd_file, dtype=np.float32).reshape((-1,4))[:,:3]
174+
elif pcd_file.endswith('.ply'):
175+
return np.array(o3d.io.read_point_cloud(pcd_file).points)
176+
else:
177+
print(f"Point cloud format '.{pcd_file.split('.')[-1]}' not supported. (supported formats: .bin (kitti format), .ply)")
178+
172179
@click.command()
173-
@click.option('--diff', '-d', type=str, default='', help='path to the scan sequence')
174-
@click.option('--refine', '-r', type=str, default='', help='path to the scan sequence')
180+
@click.option('--diff', '-d', type=str, default='checkpoints/diff_net.ckpt', help='path to the scan sequence')
181+
@click.option('--refine', '-r', type=str, default='checkpoints/refine_net.ckpt', help='path to the scan sequence')
175182
@click.option('--denoising_steps', '-T', type=int, default=50, help='number of denoising steps (default: 50)')
176183
@click.option('--cond_weight', '-s', type=float, default=6.0, help='conditioning weight (default: 6.0)')
177184
def main(diff, refine, denoising_steps, cond_weight):
@@ -188,8 +195,7 @@ def main(diff, refine, denoising_steps, cond_weight):
188195

189196
for pcd_path in tqdm.tqdm(natsorted(os.listdir(path))):
190197
pcd_file = os.path.join(path, pcd_path)
191-
input_pcd = o3d.io.read_point_cloud(pcd_file)
192-
points = np.array(input_pcd.points)
198+
points = load_pcd(pcd_file)
193199

194200
start = time.time()
195201
refine_scan, diff_scan = diff_completion.complete_scan(points)

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from setuptools import setup, find_packages
22

3-
pkg_name = 'pcdiff'
3+
pkg_name = 'lidiff'
44
setup(name=pkg_name, version='1.0', packages=find_packages())

0 commit comments

Comments
 (0)