66import b3d .utils as utils
77import b3d .chisight .dense .differentiable_renderer as rendering
88import rerun as rr
9+ import numpy as np
910
1011def uniformpose_meshes_to_image_model__factory (likelihood ):
1112 """
@@ -77,15 +78,15 @@ def rr_log_uniformpose_meshes_to_image_model_trace(trace, renderer, **kwargs):
7778 """
7879 return rr_log_meshes_to_image_model_trace (trace , renderer , ** kwargs ,
7980 model_args_to_densemodel_args = (
80- lambda args : (trace .get_choices ()["camera_pose" ], trace .get_choices ()[ "poses" ] , * args )
81+ lambda args : (trace .get_choices ()["camera_pose" ], trace .get_choices ()( "poses" ). c . v , * args )
8182 ))
8283
8384def rr_log_meshes_to_image_model_trace (
8485 trace , renderer ,
8586 prefix = "trace" ,
8687 timeless = False ,
8788 model_args_to_densemodel_args = (lambda x : x ),
88- transform = Pose .identity ()
89+ transform_Viz_Trace = Pose .identity ()
8990 ):
9091 """
9192 Log to rerun a visualization of a trace from `meshes_to_image_model`.
@@ -94,11 +95,16 @@ def rr_log_meshes_to_image_model_trace(
9495 to visualize traces from other models that have the same return value as `meshes_to_image_model`.
9596 This function will call `model_args_to_densemodel_args` on the arguments of the given trace,
9697 and should produce arguments of the form accepted by `meshes_to_image_model`.
98+
99+ The argument `transform_Viz_Trace` can be used to visualize the trace at a transformed
100+ coordinate frame. `transform_Viz_Trace` is a Pose object so that for a 3D point
101+ `point_Trace` in the trace, `transform_Viz_Trace.apply(point_Trace)` is the corresponding
102+ 3D point in the visualizer.
97103 """
98104 # 2D:
99105 (observed_rgbd , metadata ) = trace .get_retval ()
100- rr .log (f"/{ prefix } /rgb/observed" , rr .Image (observed_rgbd [:, :, :3 ]), timeless = timeless )
101- rr .log (f"/{ prefix } /depth/observed" , rr .DepthImage (observed_rgbd [:, :, 3 ]), timeless = timeless )
106+ rr .log (f"/{ prefix } /rgb/observed" , rr .Image (np . array ( observed_rgbd [:, :, :3 ]) ), timeless = timeless )
107+ rr .log (f"/{ prefix } /depth/observed" , rr .DepthImage (np . array ( observed_rgbd [:, :, 3 ]) ), timeless = timeless )
102108
103109 # Visualization path for the average render,
104110 # if the likelihood metadata contains the output of the differentiable renderer.
@@ -107,46 +113,45 @@ def rr_log_meshes_to_image_model_trace(
107113 avg_obs = rendering .dist_params_to_average (weights , attributes , jnp .zeros (4 ))
108114 avg_obs_rgb_clipped = jnp .clip (avg_obs [:, :, :3 ], 0 , 1 )
109115 avg_obs_depth_clipped = jnp .clip (avg_obs [:, :, 3 ], 0 , 1 )
110- rr .log (f"/{ prefix } /rgb/average_render" , rr .Image (avg_obs_rgb_clipped ), timeless = timeless )
111- rr .log (f"/{ prefix } /depth/average_render" , rr .DepthImage (avg_obs_depth_clipped ), timeless = timeless )
116+ rr .log (f"/{ prefix } /rgb/average_render" , rr .Image (np . array ( avg_obs_rgb_clipped ) ), timeless = timeless )
117+ rr .log (f"/{ prefix } /depth/average_render" , rr .DepthImage (np . array ( avg_obs_depth_clipped ) ), timeless = timeless )
112118
113119 # 3D:
114- rr .log (f"/{ prefix } " , rr .Transform3D (translation = transform .pos , mat3x3 = transform .rot .as_matrix ()), timeless = timeless )
120+ rr .log (f"/{ prefix } /3D/ " , rr .Transform3D (translation = transform_Viz_Trace .pos , mat3x3 = transform_Viz_Trace .rot .as_matrix ()), timeless = timeless )
115121
116122 (X_WC , Xs_WO , vertices_O , faces , vertex_colors ) = model_args_to_densemodel_args (trace .get_args ())
117- Xs_WO = trace .strip ()["poses" ].inner .value # TODO: do this better
118123 vertices_W = jax .vmap (lambda X_WO , v_O : X_WO .apply (v_O ), in_axes = (0 , 0 ))(Xs_WO , vertices_O )
119124 N = vertices_O .shape [0 ]
120125 f = jax .vmap (lambda i , f : f + i * vertices_O .shape [1 ], in_axes = (0 , 0 ))(jnp .arange (N ), faces )
121126 f = f .reshape (- 1 , 3 )
122127
123- rr .log (f"/{ prefix } /mesh" , rr .Mesh3D (
124- vertex_positions = vertices_W .reshape (- 1 , 3 ),
125- triangle_indices = f ,
126- vertex_colors = vertex_colors .reshape (- 1 , 3 )
128+ rr .log (f"/{ prefix } /3D/ mesh" , rr .Mesh3D (
129+ vertex_positions = np . array ( vertices_W .reshape (- 1 , 3 ) ),
130+ triangle_indices = np . array ( f ) ,
131+ vertex_colors = np . array ( vertex_colors .reshape (- 1 , 3 ) )
127132 ), timeless = timeless )
128133
129- rr .log (f"/{ prefix } /camera" ,
134+ rr .log (f"/{ prefix } /3D/ camera" ,
130135 rr .Pinhole (
131136 focal_length = [float (renderer .fx ), float (renderer .fy )],
132137 width = renderer .width ,
133138 height = renderer .height ,
134139 principal_point = jnp .array ([renderer .cx , renderer .cy ]),
135140 ), timeless = timeless
136141 )
137- rr .log (f"/{ prefix } /camera" , rr .Transform3D (translation = X_WC .pos , mat3x3 = X_WC .rot .as_matrix ()), timeless = timeless )
142+ rr .log (f"/{ prefix } /3D/ camera" , rr .Transform3D (translation = X_WC .pos , mat3x3 = X_WC .rot .as_matrix ()), timeless = timeless )
138143 xyzs_C = utils .xyz_from_depth (observed_rgbd [:, :, 3 ], renderer .fx , renderer .fy , renderer .cx , renderer .cy )
139144 xyzs_W = X_WC .apply (xyzs_C )
140- rr .log (f"/{ prefix } /gt_pointcloud" , rr .Points3D (
141- positions = xyzs_W .reshape (- 1 ,3 ),
142- colors = observed_rgbd [:, :, :3 ].reshape (- 1 ,3 ),
143- radii = 0.001 * jnp .ones (xyzs_W .reshape (- 1 ,3 ).shape [0 ])),
145+ rr .log (f"/{ prefix } /3D/ gt_pointcloud" , rr .Points3D (
146+ positions = np . array ( xyzs_W .reshape (- 1 ,3 ) ),
147+ colors = np . array ( observed_rgbd [:, :, :3 ].reshape (- 1 ,3 ) ),
148+ radii = 0.001 * np .ones (xyzs_W .reshape (- 1 ,3 ).shape [0 ])),
144149 timeless = timeless
145150 )
146151
147152 patch_centers_W = jax .vmap (lambda X_WO : X_WO .pos )(Xs_WO )
148153 rr .log (
149- f"/{ prefix } /patch_centers_W " ,
150- rr .Points3D (positions = patch_centers_W , colors = jnp .array ([0. , 0. , 1. ]), radii = 0.003 ),
154+ f"/{ prefix } /3D/patch_centers " ,
155+ rr .Points3D (positions = np . array ( patch_centers_W ) , colors = np .array ([0. , 0. , 1. ]), radii = 0.003 ),
151156 timeless = timeless
152157 )
0 commit comments