diff --git a/.github/workflows/core_tests.yml b/.github/workflows/core_tests.yml index 3e8e62176..7a06c3947 100644 --- a/.github/workflows/core_tests.yml +++ b/.github/workflows/core_tests.yml @@ -2,9 +2,9 @@ name: Core Tests. on: push: - branches: [master] + branches: [main] pull_request: - branches: [master] + branches: [main] permissions: contents: read diff --git a/diff_rast/_torch_impl.py b/diff_rast/_torch_impl.py index 009d24d51..95c2bb8d0 100644 --- a/diff_rast/_torch_impl.py +++ b/diff_rast/_torch_impl.py @@ -6,7 +6,9 @@ from torch import Tensor -def compute_sh_color(viewdirs: Float[Tensor, "*batch 3"], sh_coeffs: Float[Tensor, "*batch D C"]): +def compute_sh_color( + viewdirs: Float[Tensor, "*batch 3"], sh_coeffs: Float[Tensor, "*batch D C"] +): """ :param viewdirs (*, C) :param sh_coeffs (*, D, C) sh coefficients for each color channel @@ -66,7 +68,9 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor): :return: torch.Tensor (..., basis_dim) """ - result = torch.empty((*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device) + result = torch.empty( + (*dirs.shape[:-1], basis_dim), dtype=dirs.dtype, device=dirs.device + ) result[..., 0] = SH_C0 if basis_dim > 1: x, y, z = dirs.unbind(-1) @@ -100,7 +104,9 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor): result[..., 21] = SH_C4[5] * xz * (7 * zz - 3) result[..., 22] = SH_C4[6] * (xx - yy) * (7 * zz - 1) result[..., 23] = SH_C4[7] * xz * (xx - 3 * yy) - result[..., 24] = SH_C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) + result[..., 24] = SH_C4[8] * ( + xx * (xx - 3 * yy) - yy * (3 * xx - yy) + ) return result @@ -148,7 +154,9 @@ def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor return M @ M.transpose(-1, -2) # (..., 3, 3) -def project_cov3d_ewa(mean3d: Tensor, cov3d: Tensor, viewmat: Tensor, fx: float, fy: float) -> Tensor: +def project_cov3d_ewa( + mean3d: Tensor, cov3d: Tensor, viewmat: Tensor, fx: float, fy: float +) -> Tensor: assert mean3d.shape[-1] == 3, mean3d.shape assert cov3d.shape[-2:] == (3, 3), cov3d.shape assert viewmat.shape[-2:] == (4, 4), viewmat.shape @@ -212,7 +220,9 @@ def clip_near_plane(p, viewmat, thresh=0.1): def get_tile_bbox(pix_center, pix_radius, tile_bounds, BLOCK_X=16, BLOCK_Y=16): - tile_size = torch.tensor([BLOCK_X, BLOCK_Y], dtype=torch.float32, device=pix_center.device) + tile_size = torch.tensor( + [BLOCK_X, BLOCK_Y], dtype=torch.float32, device=pix_center.device + ) tile_center = pix_center / tile_size tile_radius = pix_radius[..., None] / tile_size @@ -253,7 +263,9 @@ def project_gaussians_forward( conic, radius, det_valid = compute_cov2d_bounds(cov2d) center = project_pix(projmat, means3d, img_size) tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds) - tile_area = (tile_max[..., 0] - tile_min[..., 0]) * (tile_max[..., 1] - tile_min[..., 1]) + tile_area = (tile_max[..., 0] - tile_min[..., 0]) * ( + tile_max[..., 1] - tile_min[..., 1] + ) mask = (tile_area > 0) & (~is_close) & det_valid num_tiles_hit = tile_area diff --git a/diff_rast/cov2d_bounds.py b/diff_rast/cov2d_bounds.py index ebb32c0cf..0a946c612 100644 --- a/diff_rast/cov2d_bounds.py +++ b/diff_rast/cov2d_bounds.py @@ -29,7 +29,9 @@ def forward( ), f"Expected input cov2d to be of shape (*batch, 3) (upper triangular values), but got {tuple(cov2d.shape)}" num_pts = cov2d.shape[0] assert num_pts > 0 - conic, radius = _C.compute_cov2d_bounds_forward(num_pts, cov2d.contiguous().cuda()) + conic, radius = _C.compute_cov2d_bounds_forward( + num_pts, cov2d.contiguous().cuda() + ) return (conic, radius) @staticmethod diff --git a/diff_rast/project_gaussians.py b/diff_rast/project_gaussians.py index 13ad4990e..ad9e53a54 100644 --- a/diff_rast/project_gaussians.py +++ b/diff_rast/project_gaussians.py @@ -40,7 +40,7 @@ def forward( img_height: int, img_width: int, tile_bounds: Tuple[int, int, int], - clip_thresh:float=0.01 + clip_thresh: float = 0.01, ): num_points = means3d.shape[-2] @@ -102,13 +102,7 @@ def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d): conics, ) = ctx.saved_tensors - ( - v_cov2d, - v_cov3d, - v_mean3d, - v_scale, - v_quat, - ) = _C.project_gaussians_backward( + (v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat,) = _C.project_gaussians_backward( ctx.num_points, means3d, scales, diff --git a/tests/test_cov2d_bounds.py b/tests/test_cov2d_bounds.py index f1613dd40..aae9848bd 100644 --- a/tests/test_cov2d_bounds.py +++ b/tests/test_cov2d_bounds.py @@ -14,9 +14,16 @@ def compare_binding_to_pytorch(): num_cov2ds = 2 - _covs2d = torch.rand((num_cov2ds, 2, 2), dtype=torch.float32, device=device, requires_grad=True) + _covs2d = torch.rand( + (num_cov2ds, 2, 2), dtype=torch.float32, device=device, requires_grad=True + ) covs2d = torch.stack( - [torch.triu(_covs2d)[:, 0, 0], torch.triu(_covs2d)[:, 0, 1], torch.triu(_covs2d)[:, 1, 1]], dim=-1 + [ + torch.triu(_covs2d)[:, 0, 0], + torch.triu(_covs2d)[:, 0, 1], + torch.triu(_covs2d)[:, 1, 1], + ], + dim=-1, ) conic, radii = compute_cov2d_bounds.apply(covs2d)