@@ -1512,23 +1512,37 @@ def rasterization_2dgs(
15121512 # if packed:
15131513 # colors = colors.view(B, C, N, -1)[batch_ids, camera_ids, gaussian_ids, :]
15141514 if sh_degree is not None : # SH coefficients
1515- camtoworlds = torch .inverse (viewmats )
1515+ # Colors are SH coefficients, with shape [..., N, K, 3] or [..., C, N, K, 3]
1516+ campos = torch .inverse (viewmats )[..., :3 , 3 ] # [..., C, 3]
15161517 if packed :
1517- dirs = means [..., gaussian_ids , :] - camtoworlds [..., camera_ids , :3 , 3 ]
1518- else :
1519- dirs = means [..., None , :, :] - camtoworlds [..., None , :3 , 3 ]
1520-
1521- if colors .dim () == num_batch_dims + 3 :
1522- # Turn [..., N, K, 3] into [..., C, N, K, 3]
1523- shs = torch .broadcast_to (
1524- colors [..., None , :, :, :], batch_dims + (C , N , - 1 , 3 )
1525- ) # [..., C, N, K, 3]
1518+ dirs = (
1519+ means .view (B , N , 3 )[batch_ids , gaussian_ids ]
1520+ - campos .view (B , C , 3 )[batch_ids , camera_ids ]
1521+ ) # [nnz, 3]
1522+ masks = (radii > 0 ).all (dim = - 1 ) # [nnz]
1523+ if colors .dim () == num_batch_dims + 3 :
1524+ # Turn [..., N, K, 3] into [nnz, 3]
1525+ shs = colors .view (B , N , - 1 , 3 )[batch_ids , gaussian_ids ] # [nnz, K, 3]
1526+ else :
1527+ # Turn [..., C, N, K, 3] into [nnz, 3]
1528+ shs = colors .view (B , C , N , - 1 , 3 )[
1529+ batch_ids , camera_ids , gaussian_ids
1530+ ] # [nnz, K, 3]
1531+ colors = spherical_harmonics (sh_degree , dirs , shs , masks = masks ) # [nnz, 3]
15261532 else :
1527- # colors is already [..., C, N, K, 3]
1528- shs = colors
1529- colors = spherical_harmonics (
1530- sh_degree , dirs , shs , masks = (radii > 0 ).all (dim = - 1 )
1531- ) # [nnz, D] or [..., C, N, 3]
1533+ dirs = means [..., None , :, :] - campos [..., None , :] # [..., C, N, 3]
1534+ masks = (radii > 0 ).all (dim = - 1 ) # [..., C, N]
1535+ if colors .dim () == num_batch_dims + 3 :
1536+ # Turn [..., N, K, 3] into [..., C, N, K, 3]
1537+ shs = torch .broadcast_to (
1538+ colors [..., None , :, :, :], batch_dims + (C , N , - 1 , 3 )
1539+ )
1540+ else :
1541+ # colors is already [..., C, N, K, 3]
1542+ shs = colors
1543+ colors = spherical_harmonics (
1544+ sh_degree , dirs , shs , masks = masks
1545+ ) # [..., C, N, 3]
15321546 # make it apple-to-apple with Inria's CUDA Backend.
15331547 colors = torch .clamp_min (colors + 0.5 , 0.0 )
15341548
0 commit comments