from math import isqrt
from typing import Literal

import torch
from diff_surfel_rasterization import (
    GaussianRasterizationSettings,
    GaussianRasterizer,
)
from einops import einsum, rearrange, repeat
from jaxtyping import Float
from torch import Tensor

from ...geometry.projection import get_fov, homogenize_points


def get_projection_matrix(
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    fov_x: Float[Tensor, " batch"],
    fov_y: Float[Tensor, " batch"],
) -> Float[Tensor, "batch 4 4"]:
    """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on
    the Z axis. Differs from the OpenGL version in that Z doesn't have range
    (-1, 1) after transformation and that Z is flipped.
    """
    tan_fov_x = (0.5 * fov_x).tan()
    tan_fov_y = (0.5 * fov_y).tan()

    top = tan_fov_y * near
    bottom = -top
    right = tan_fov_x * near
    left = -right

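    # In row-major block form, the matrix assembled below is
    #     [ 2n/(r-l)     0       (r+l)/(r-l)        0      ]
    #     [    0      2n/(t-b)   (t+b)/(t-b)        0      ]
    #     [    0         0        f/(f-n)      -f*n/(f-n)  ]
    #     [    0         0           1              0      ]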
    (b,) = near.shape
    result = torch.zeros((b, 4, 4), dtype=torch.float32, device=near.device)
    result[:, 0, 0] = 2 * near / (right - left)
    result[:, 1, 1] = 2 * near / (top - bottom)
    result[:, 0, 2] = (right + left) / (right - left)
    result[:, 1, 2] = (top + bottom) / (top - bottom)
    result[:, 3, 2] = 1
    result[:, 2, 2] = far / (far - near)
    result[:, 2, 3] = -(far * near) / (far - near)
    return result


def render_cuda_2dgs(
    extrinsics: Float[Tensor, "batch 4 4"],
    intrinsics: Float[Tensor, "batch 3 3"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    scale_invariant: bool = True,
    use_sh: bool = True,
    # cam_rot_delta: Float[Tensor, "batch 3"] | None = None,
    # cam_trans_delta: Float[Tensor, "batch 3"] | None = None,
) -> tuple[
    Float[Tensor, "batch 3 height width"],
    Float[Tensor, "batch _ height width"],
]:
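    """Render 2D Gaussian surfels (2DGS) for each batch element with the
    diff-surfel-rasterization CUDA backend. Returns the rendered RGB images and
    the rasterizer's stacked auxiliary maps (per-pixel quantities such as depth,
    normals, and alpha; see the 2DGS rasterizer for the exact channel layout).
    """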
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    # Make sure everything is in a range where numerical issues don't appear by
    # rescaling the scene so that the near plane sits at unit distance.
    if scale_invariant:
        scale = 1 / near
        extrinsics = extrinsics.clone()
        extrinsics[..., :3, 3] = extrinsics[..., :3, 3] * scale[:, None]
        gaussian_covariances = gaussian_covariances * (scale[:, None, None, None] ** 2)
        gaussian_means = gaussian_means * scale[:, None, None]
        near = near * scale
        far = far * scale

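    # A degree-d spherical-harmonics expansion has (d + 1) ** 2 coefficients per
    # channel, so the degree can be recovered with isqrt.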
    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    b, _, _ = extrinsics.shape
    h, w = image_shape

    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    tan_fov_x = (0.5 * fov_x).tan()
    tan_fov_y = (0.5 * fov_y).tan()

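    # The CUDA rasterizer expects transposed matrices (a column-major convention
    # inherited from the original 3DGS code), hence "b i j -> b j i".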
    projection_matrix = get_projection_matrix(near, far, fov_x, fov_y)
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    all_images = []
    all_radii = []
    all_maps = []

    for i in range(b):
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x[i].item(),
            tanfovy=tan_fov_y[i].item(),
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            # projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=False,
        )
        rasterizer = GaussianRasterizer(settings)

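        # The rasterizer consumes each covariance as its six upper-triangular
        # entries rather than the full symmetric 3x3 matrix.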
        row, col = torch.triu_indices(3, 3)

        image, radii, allmap = rasterizer(
            means3D=gaussian_means[i],
            means2D=mean_gradients,
            shs=shs[i] if use_sh else None,
            colors_precomp=None if use_sh else shs[i, :, 0, :],
            opacities=gaussian_opacities[i, ..., None],
            cov3D_precomp=gaussian_covariances[i, :, row, col],
            # theta=cam_rot_delta[i] if cam_rot_delta is not None else None,
            # rho=cam_trans_delta[i] if cam_trans_delta is not None else None,
        )
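        # allmap stacks the rasterizer's auxiliary per-pixel outputs; in 2DGS these
        # include alpha, depth, and normal channels (see the rasterizer for the
        # exact layout).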
        all_images.append(image)
        all_radii.append(radii)
        all_maps.append(allmap.squeeze(0))
    return torch.stack(all_images), torch.stack(all_maps)


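# A minimal usage sketch for render_cuda_2dgs, assuming normalized intrinsics and
# CUDA tensors; all shapes and values below are made-up placeholders, not defaults
# from any real pipeline:
#
#     b, g = 2, 1024
#     images, maps = render_cuda_2dgs(
#         extrinsics=torch.eye(4, device="cuda").expand(b, 4, 4),
#         intrinsics=torch.tensor(
#             [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], device="cuda"
#         ).expand(b, 3, 3),
#         near=torch.full((b,), 0.1, device="cuda"),
#         far=torch.full((b,), 100.0, device="cuda"),
#         image_shape=(256, 256),
#         background_color=torch.zeros(b, 3, device="cuda"),
#         gaussian_means=torch.randn(b, g, 3, device="cuda"),
#         gaussian_covariances=1e-4 * torch.eye(3, device="cuda").expand(b, g, 3, 3),
#         gaussian_sh_coefficients=torch.rand(b, g, 3, 1, device="cuda"),
#         gaussian_opacities=torch.rand(b, g, device="cuda"),
#         use_sh=False,
#     )

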
def render_cuda_orthographic(
    extrinsics: Float[Tensor, "batch 4 4"],
    width: Float[Tensor, " batch"],
    height: Float[Tensor, " batch"],
    near: Float[Tensor, " batch"],
    far: Float[Tensor, " batch"],
    image_shape: tuple[int, int],
    background_color: Float[Tensor, "batch 3"],
    gaussian_means: Float[Tensor, "batch gaussian 3"],
    gaussian_covariances: Float[Tensor, "batch gaussian 3 3"],
    gaussian_sh_coefficients: Float[Tensor, "batch gaussian 3 d_sh"],
    gaussian_opacities: Float[Tensor, "batch gaussian"],
    fov_degrees: float = 0.1,
    use_sh: bool = True,
    dump: dict | None = None,
) -> Float[Tensor, "batch 3 height width"]:
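    """Approximate an orthographic render by pushing the camera back and using a
    very narrow perspective frustum (fov_degrees wide).
    """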
    b, _, _ = extrinsics.shape
    h, w = image_shape
    assert use_sh or gaussian_sh_coefficients.shape[-1] == 1

    _, _, _, n = gaussian_sh_coefficients.shape
    degree = isqrt(n) - 1
    shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()

    # Create a fake "orthographic" projection by moving the camera back and picking
    # a small field of view.
    fov_x = torch.tensor(fov_degrees, device=extrinsics.device).deg2rad()
    tan_fov_x = (0.5 * fov_x).tan()
    distance_to_near = (0.5 * width) / tan_fov_x
    tan_fov_y = 0.5 * height / distance_to_near
    fov_y = 2 * tan_fov_y.atan()
    near = near + distance_to_near
    far = far + distance_to_near
    move_back = repeat(
        torch.eye(4, dtype=torch.float32, device=extrinsics.device), "i j -> b i j", b=b
    ).clone()
    move_back[:, 2, 3] = -distance_to_near
    extrinsics = extrinsics @ move_back

    # Escape hatch for visualization/figures.
    if dump is not None:
        dump["extrinsics"] = extrinsics
        dump["fov_x"] = fov_x
        dump["fov_y"] = fov_y
        dump["near"] = near
        dump["far"] = far

    projection_matrix = get_projection_matrix(
        near, far, repeat(fov_x, "-> b", b=b), fov_y
    )
    projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
    view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
    full_projection = view_matrix @ projection_matrix

    all_images = []
    all_radii = []
    for i in range(b):
        # Set up a tensor for the gradients of the screen-space means.
        mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
        try:
            mean_gradients.retain_grad()
        except Exception:
            pass

        settings = GaussianRasterizationSettings(
            image_height=h,
            image_width=w,
            tanfovx=tan_fov_x.item(),
            tanfovy=tan_fov_y[i].item(),
            bg=background_color[i],
            scale_modifier=1.0,
            viewmatrix=view_matrix[i],
            projmatrix=full_projection[i],
            # projmatrix_raw=projection_matrix[i],
            sh_degree=degree,
            campos=extrinsics[i, :3, 3],
            prefiltered=False,  # This matches the original usage.
            debug=False,
        )
        rasterizer = GaussianRasterizer(settings)

        row, col = torch.triu_indices(3, 3)

        # The surfel rasterizer returns (image, radii, allmap); the auxiliary maps
        # are unused here.
        image, radii, _ = rasterizer(
            means3D=gaussian_means[i],
            means2D=mean_gradients,
            shs=shs[i] if use_sh else None,
            colors_precomp=None if use_sh else shs[i, :, 0, :],
            opacities=gaussian_opacities[i, ..., None],
            cov3D_precomp=gaussian_covariances[i, :, row, col],
        )
        all_images.append(image)
        all_radii.append(radii)
    return torch.stack(all_images)


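# Naming for the supported depth-visualization conventions.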
DepthRenderingMode = Literal["depth", "disparity", "relative_disparity", "log"]