
Commit 5b18993: 2dgs splatter
1 parent: e51a996

20 files changed: +763 -69 lines

config/experiment/re10k_1x8_ours.yaml (+2 -2)

@@ -4,7 +4,7 @@ defaults:
   - /dataset@_group_.re10k: re10k
   - override /model/encoder: noposplat
   - override /model/encoder/backbone: croco
-  - override /model/decoder: gsplat_2dgs
+  - override /model/decoder: splatting_cuda_2dgs #splatting_cuda | gsplat_cuda | splatting_cuda_2dgs
   - override /loss: [mse, lpips]

 wandb:

@@ -33,7 +33,7 @@ optimizer:

 data_loader:
   train:
-    batch_size: 6
+    batch_size: 1

 trainer:
   max_steps: 200_001

(gsplat decoder config; file path not shown in this capture)

@@ -1,12 +1,12 @@
-name: gsplat_2dgs
+name: gsplat_cuda
 background_color: null
 make_scale_invariant: false
 radius_clip: 0.0
 eps2d: 0.3
 sh_degree: 3
 packed: true
 tile_size: 16
-render_mode: RGB
+render_mode: RGB+D
 sparse_grad: false
 absgrad: false
 distloss: false

(new splatting_cuda_2dgs decoder config; file path not shown in this capture)

@@ -0,0 +1,3 @@
+name: splatting_cuda_2dgs
+background_color: [0.0, 0.0, 0.0]
+make_scale_invariant: false
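
The three keys in this new YAML presumably populate a config dataclass for the new decoder. A minimal sketch of what that dataclass could look like, assuming it mirrors the existing DecoderSplattingCUDACfg pattern (the actual fields of DecoderSplattingCUDA2DGSCfg are not visible in this commit):

from dataclasses import dataclass
from typing import Literal


@dataclass
class DecoderSplattingCUDA2DGSCfg:
    # Field names inferred from the YAML keys above; treat this as an
    # assumption, not the committed definition.
    name: Literal["splatting_cuda_2dgs"]
    background_color: list[float]
    make_scale_invariant: bool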

diff-surfel-rasterization (new submodule, +1)

@@ -0,0 +1 @@
+Subproject commit e0ed0207b3e0669960cfad70852200a4a5847f61

pixi.lock (+250)

(Generated lock file; diff not rendered.)

pixi.toml (+8 -1)

@@ -110,5 +110,12 @@ nvidia-dali-cuda120 = {version = "*"}
 [feature.deepspeed.pypi-dependencies]
 deepspeed = {version = "*"}

+[feature.2dgs.pypi-dependencies]
+open3d = {version = ">=0.18.0,<=0.19.0"}
+mediapy = "*"
+
+[feature.2dgs.tasks]
+install-gaussian-surfel-rasterizer = "git clone https://github.com/hbb1/diff-surfel-rasterization.git && pip install ./diff-surfel-rasterization"
+
 [environments]
-npsplat = {features = ["py310", "torch241cu121", "ropebuild", "noposplat", "gsplat", "pycharm", "dali", "deepspeed"]}
+npsplat = {features = ["py310", "torch241cu121", "ropebuild", "noposplat", "gsplat", "pycharm", "dali", "deepspeed", "2dgs"]}
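
A quick way to confirm the new feature wiring is a Python import smoke test, assuming the npsplat pixi environment is active and the install-gaussian-surfel-rasterizer task has been run (this check is not part of the commit):

# Hypothetical smoke test: the 2dgs feature pulls in open3d/mediapy, and the
# pixi task builds the diff_surfel_rasterization extension from source.
import torch
import open3d
import mediapy
from diff_surfel_rasterization import GaussianRasterizationSettings, GaussianRasterizer

# The surfel rasterizer needs a CUDA device at render time.
print("CUDA available:", torch.cuda.is_available())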

src/dataset/dataset_re10k.py (+7 -1)

@@ -203,9 +203,15 @@ def __iter__(self):
                 },
                 "scene": scene,
             }
+
             if self.stage == "train" and self.cfg.augment:
                 example = apply_augmentation_shim(example)
-            yield apply_crop_shim(example, tuple(self.cfg.input_image_shape))
+
+            crop_example = apply_crop_shim(example, tuple(self.cfg.input_image_shape))
+
+            #print(f"\nBefore crop\n{example['target']['intrinsics'][0]} \n\nAfter crop\n{crop_example['target']['intrinsics'][0]}")
+            #print(f"After crop\n{crop_example['target']['intrinsics'][0]}")
+            yield crop_example

     def convert_poses(
         self,
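
The commented-out prints above compare the normalized target intrinsics before and after apply_crop_shim. For reference, a standalone sketch of that adjustment for a centered crop followed by a resize, assuming the pixelSplat/NoPoSplat convention of intrinsics normalized by image width and height (apply_crop_shim remains the authoritative implementation):

import torch


def center_crop_normalized_intrinsics(
    k: torch.Tensor,           # (3, 3) intrinsics normalized by image size
    old_hw: tuple[int, int],   # (H, W) before cropping
    crop_hw: tuple[int, int],  # (h, w) of the centered crop
) -> torch.Tensor:
    """Illustrative only: a centered crop rescales normalized focal lengths by
    old/new size and keeps a centered principal point at 0.5; a subsequent
    resize leaves normalized intrinsics unchanged."""
    (H, W), (h, w) = old_hw, crop_hw
    k = k.clone()
    # Focal lengths are unchanged in pixels, so they grow in normalized units.
    k[0, 0] *= W / w
    k[1, 1] *= H / h
    # Principal point: shift by the crop offset, then renormalize.
    k[0, 2] = (k[0, 2] * W - (W - w) / 2) / w
    k[1, 2] = (k[1, 2] * H - (H - h) / 2) / h
    return k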

src/logger_setup.py (+2 -1)

@@ -1,7 +1,8 @@
 from omegaconf import OmegaConf
 from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.utilities import rank_zero_only

-class WandbLoggerManager:
+class WandbLoggerManager():
     _logger = None

     @classmethod
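
The decoder change further below fetches the logger through WandbLoggerManager.get_logger(). The rest of the class is not shown in this hunk; a plausible shape, inferred from the visible _logger = None attribute and the @classmethod decorator, might be the following (set_logger is an assumed name):

from lightning.pytorch.loggers.wandb import WandbLogger


class WandbLoggerManager():
    # Process-wide handle to the active WandbLogger (assumed API).
    _logger = None

    @classmethod
    def set_logger(cls, logger: WandbLogger) -> None:
        cls._logger = logger

    @classmethod
    def get_logger(cls) -> WandbLogger | None:
        return cls._logger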

src/model/decoder/__init__.py (+4 -3)

@@ -1,14 +1,15 @@
 from .decoder import Decoder
 from .decoder_splatting_cuda import DecoderSplattingCUDA, DecoderSplattingCUDACfg
+from .decoder_splatting_cuda_2dgs import DecoderSplattingCUDA2DGS, DecoderSplattingCUDA2DGSCfg
 from .decoder_splatting_gsplat_cuda import DecoderGSplattingCUDA, DecoderGSplatting2DGSCfg

 DECODERS = {
     "splatting_cuda": DecoderSplattingCUDA,
-    "gsplat_2dgs": DecoderGSplattingCUDA,
+    "splatting_cuda_2dgs": DecoderSplattingCUDA2DGS,
+    "gsplat_cuda": DecoderGSplattingCUDA,
 }

-# DecoderCfg = DecoderSplattingCUDACfg | DecoderGSplatting2DGSCfg
-DecoderCfg = DecoderGSplatting2DGSCfg
+DecoderCfg = DecoderSplattingCUDACfg | DecoderSplattingCUDA2DGSCfg | DecoderGSplatting2DGSCfg

 def get_decoder(decoder_cfg: DecoderCfg) -> Decoder:
     return DECODERS[decoder_cfg.name](decoder_cfg)
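
With the registry updated, decoder selection stays a plain dictionary lookup keyed by cfg.name. A usage sketch, where the DecoderSplattingCUDA2DGSCfg field values reuse the assumed fields from the config sketch earlier and the module path follows the relative import above:

from src.model.decoder import get_decoder
from src.model.decoder.decoder_splatting_cuda_2dgs import DecoderSplattingCUDA2DGSCfg

cfg = DecoderSplattingCUDA2DGSCfg(
    name="splatting_cuda_2dgs",
    background_color=[0.0, 0.0, 0.0],
    make_scale_invariant=False,
)
decoder = get_decoder(cfg)  # resolves to DecoderSplattingCUDA2DGS via DECODERS[cfg.name]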

src/model/decoder/cuda_splatting.py (+1)

@@ -91,6 +91,7 @@ def render_cuda(
     all_radii = []
     all_depths = []
     for i in range(b):
+        print(f"cuda_splatting i, b*v : {i}, {b}")
         # Set up a tensor for the gradients of the screen-space means.
         mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
         try:

(new 2DGS surfel rasterization module, +232; file path not shown in this capture)

from math import isqrt
from typing import Literal

import torch
from diff_surfel_rasterization import (
    GaussianRasterizationSettings,
    GaussianRasterizer,
)
from einops import einsum, rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from ...geometry.projection import get_fov, homogenize_points


def get_projection_matrix(
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    fov_x: Float[Tensor, " batch"],
    fov_y: Float[Tensor, " batch"],
) -> Float[Tensor, "batch 4 4"]:
    """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z
    axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after
    transformation and that Z is flipped.
    """
    tan_fov_x = (0.5 * fov_x).tan()
    tan_fov_y = (0.5 * fov_y).tan()

    top = tan_fov_y * near
    bottom = -top
    right = tan_fov_x * near
    left = -right

    (b,) = near.shape
    result = torch.zeros((b, 4, 4), dtype=torch.float32, device=near.device)
    result[:, 0, 0] = 2 * near / (right - left)
    result[:, 1, 1] = 2 * near / (top - bottom)
    result[:, 0, 2] = (right + left) / (right - left)
    result[:, 1, 2] = (top + bottom) / (top - bottom)
    result[:, 3, 2] = 1
    result[:, 2, 2] = far / (far - near)
    result[:, 2, 3] = -(far * near) / (far - near)
    return result


def render_cuda_2dgs(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    scale_invariant: bool = True,
    use_sh: bool = True,
    #cam_rot_delta: Float[Tensor, "batch 3"] | None = None,
    #cam_trans_delta: Float[Tensor, "batch 3"] | None = None,
) -> tuple[Float[Tensor, "batch 3 height width"], Float[Tensor, "batch _ height width"]]:
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    # Make sure everything is in a range where numerical issues don't appear.
    if scale_invariant:
        scale = 1 / near
        extrinsics = extrinsics.clone()
        extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None]
        gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2)
        gaussian_means = gaussian_means * scale[:, None, None]
        near = near * scale
        far = far * scale

    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    b, _, _ = extrinsics.shape
    h, w = image_shape

    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    tan_fov_x = (0.5 * fov_x).tan()
    tan_fov_y = (0.5 * fov_y).tan()

    projection_matrix = get_projection_matrix(near, far, fov_x, fov_y)
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    all_images = []
    all_radii = []
    all_maps = []

    for i in range(b):
        print(f"cuda_splatting 2dgs i, b*v : {i}, {b}")
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x[i].item(),
            tanfovy=tan_fov_y[i].item(),
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            #projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=True,
        )
        rasterizer = GaussianRasterizer(settings)

        row, col = torch.triu_indices(3, 3)

        image, radii, allmap = rasterizer(
            means3D=gaussian_means[i],
            means2D=mean_gradients,
            shs=shs[i] if use_sh else None,
            colors_precomp=None if use_sh else shs[i, :, 0, :],
            opacities=gaussian_opacities[i, ..., None],
            cov3D_precomp=gaussian_covariances[i, :, row, col],
            #theta=cam_rot_delta[i] if cam_rot_delta is not None else None,
            #rho=cam_trans_delta[i] if cam_trans_delta is not None else None,
        )
        all_images.append(image)
        all_radii.append(radii)
        all_maps.append(allmap.squeeze(0))
    return torch.stack(all_images), torch.stack(all_maps)


def render_cuda_orthographic(
    extrinsics: Float[Tensor, "batch 4 4"],
    width: Float[Tensor, " batch"],
    height: Float[Tensor, " batch"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    fov_degrees: float = 0.1,
    use_sh: bool = True,
    dump: dict | None = None,
) -> Float[Tensor, "batch 3 height width"]:
    b, _, _ = extrinsics.shape
    h, w = image_shape
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    # Create fake "orthographic" projection by moving the camera back and picking a
    # small field of view.
    fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad()
    tan_fov_x = (0.5 * fov_x).tan()
    distance_to_near = (0.5 * width) / tan_fov_x
    tan_fov_y = 0.5 * height / distance_to_near
    fov_y = (2 * tan_fov_y).atan()
    near = near + distance_to_near
    far = far + distance_to_near
    move_back = torch.eye(4, dtype=torch.float32, device=extrinsics.device)
    move_back[2, 3] = -distance_to_near
    extrinsics = extrinsics @ move_back

    # Escape hatch for visualization/figures.
    if dump is not None:
        dump["extrinsics"] = extrinsics
        dump["fov_x"] = fov_x
        dump["fov_y"] = fov_y
        dump["near"] = near
        dump["far"] = far

    projection_matrix = get_projection_matrix(
        near, far, repeat(fov_x, "-> b", b=b), fov_y
    )
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    all_images = []
    all_radii = []
    for i in range(b):
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x,
            tanfovy=tan_fov_y,
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=False,
        )
        rasterizer = GaussianRasterizer(settings)

        row, col = torch.triu_indices(3, 3)

        image, radii, depth, opacity, n_touched = rasterizer(
            means3D=gaussian_means[i],
            means2D=mean_gradients,
            shs=shs[i] if use_sh else None,
            colors_precomp=None if use_sh else shs[i, :, 0, :],
            opacities=gaussian_opacities[i, ..., None],
            cov3D_precomp=gaussian_covariances[i, :, row, col],
        )
        all_images.append(image)
        all_radii.append(radii)
    return torch.stack(all_images)


DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]
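
render_cuda_2dgs stacks the rasterizer's auxiliary allmap buffer alongside the color images but does not decode it here. In the reference 2DGS renderer (hbb1/2d-gaussian-splatting), that buffer is split channel-wise into alpha, normal, depth, and distortion maps; a hedged sketch of that convention, which should be verified against the pinned diff-surfel-rasterization commit before being relied on:

import torch


def split_allmap(allmap: torch.Tensor):
    """Split the per-view auxiliary buffer returned by the surfel rasterizer.
    Channel layout follows the reference 2DGS renderer and is an assumption here:
    0: alpha-weighted expected depth, 1: accumulated alpha, 2-4: camera-space
    normals, 5: median depth, 6: depth distortion."""
    alpha = allmap[1:2]
    normal = allmap[2:5]  # still in camera space
    depth_expected = torch.nan_to_num(allmap[0:1] / alpha, 0.0, 0.0)
    depth_median = torch.nan_to_num(allmap[5:6], 0.0, 0.0)
    distortion = allmap[6:7]
    return alpha, normal, depth_expected, depth_median, distortion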

src/model/decoder/decoder.py (+2 -1)

@@ -16,7 +16,7 @@


 @dataclass
-class DecoderOutput:
+class DecoderOutput:
     color: Float[Tensor, "batch view 3 height width"]
     depth: Float[Tensor, "batch view height width"] | None

@@ -41,5 +41,6 @@ def forward(
         far: Float[Tensor, "batch view"],
         image_shape: tuple[int, int],
         depth_mode: DepthRenderingMode | None = None,
+        global_step: int | None = None,
     ) -> DecoderOutput:
         pass

src/model/decoder/decoder_splatting_cuda.py (+11)

@@ -11,6 +11,7 @@
 from .cuda_splatting import DepthRenderingMode, render_cuda
 from .decoder import Decoder, DecoderOutput

+from src.logger_setup import WandbLoggerManager

 @dataclass
 class DecoderSplattingCUDACfg:

@@ -45,8 +46,12 @@ def forward(
         depth_mode: DepthRenderingMode | None = None,
         cam_rot_delta: Float[Tensor, "batch view 3"] | None = None,
         cam_trans_delta: Float[Tensor, "batch view 3"] | None = None,
+        global_step: int | None = None,
     ) -> DecoderOutput:
+        wandb_logger = WandbLoggerManager.get_logger()
+
         b, v, _, _ = extrinsics.shape
+
         color, depth = render_cuda(
             rearrange(extrinsics, "b v i j -> (b v) i j"),
             rearrange(intrinsics, "b v i j -> (b v) i j"),

@@ -64,5 +69,11 @@ def forward(
         )
         color = rearrange(color, "(b v) c h w -> b v c h w", b=b, v=v)

+        wandb_logger.log_image(
+            "rasterized output of first batch",
+            [color[0, i] for i in range(color.shape[1])],
+            step=global_step,
+        )
+
         depth = rearrange(depth, "(b v) h w -> b v h w", b=b, v=v)
         return DecoderOutput(color, depth)
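
As committed, log_image runs on every forward pass of the decoder. The rank_zero_only import added to src/logger_setup.py hints at the intended guard; a sketch of how such logging is typically restricted to rank zero and throttled to an interval (an assumption, not the committed code; the helper name and interval are hypothetical):

from lightning.pytorch.utilities import rank_zero_only


@rank_zero_only
def log_rasterized_views(wandb_logger, color, global_step, every_n_steps: int = 500):
    """Log the first batch's rendered views, but only on rank zero and only
    every `every_n_steps` steps (the interval is an arbitrary example)."""
    if wandb_logger is None or global_step is None or global_step % every_n_steps != 0:
        return
    wandb_logger.log_image(
        "rasterized output of first batch",
        [color[0, i] for i in range(color.shape[1])],
        step=global_step,
    )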
