88from b3d .chisight .dense .dense_likelihood import make_dense_observation_model , DenseImageLikelihoodArgs
99from b3d import Pose , Mesh
1010from b3d .chisight .sparse .gps_utils import add_dummy_var
11- from b3d .chisight . sparse .pose_utils import uniform_pose_in_ball
11+ from b3d .pose .pose_utils import uniform_pose_in_ball
1212dummy_mapped_uniform_pose = add_dummy_var (uniform_pose_in_ball ).vmap (in_axes = (0 ,None ,None ,None ))
1313
1414
@@ -113,7 +113,8 @@ def particle_system_state_step(carried_state, _):
113113
114114@gen
115115def latent_particle_model (
116- num_timesteps , # const object
116+ max_num_timesteps , # const object
117+ num_timesteps ,
117118 num_particles , # const object
118119 num_clusters , # const object
119120 relative_particle_poses_prior_params ,
@@ -132,32 +133,49 @@ def latent_particle_model(
132133 camera_pose_prior_params
133134 ) @ "state0"
134135
135- final_state , scan_retvals = particle_system_state_step .scan (n = (num_timesteps .const - 1 ))(state0 , None ) @ "states1+"
136+ masked_final_state , masked_scan_retvals = b3d .modeling_utils .masked_scan_combinator (
137+ particle_system_state_step ,
138+ n = (max_num_timesteps .const - 1 )
139+ )(
140+ state0 ,
141+ genjax .Mask (
142+ # This next line tells the scan combinator how many timesteps to run
143+ jnp .arange (max_num_timesteps .const - 1 ) < num_timesteps - 1 ,
144+ jnp .zeros (max_num_timesteps .const - 1 )
145+ )
146+ ) @ "states1+"
147+
136148
137149 # concatenate each element of init_retval, scan_retvals
138- return jax .tree .map (
150+ concatenated_states_possibly_invalid = jax .tree .map (
139151 lambda t1 , t2 : jnp .concatenate ([t1 [None , :], t2 ], axis = 0 ),
140- init_retval , scan_retvals
152+ init_retval , masked_scan_retvals .value
153+ )
154+ masked_concatenated_states = genjax .Mask (
155+ jnp .concatenate ([jnp .array ([True ]), masked_scan_retvals .flag ]),
156+ concatenated_states_possibly_invalid
141157 )
158+ return masked_concatenated_states
142159
143160@genjax .gen
144161def sparse_observation_model (particle_absolute_poses , camera_pose , visibility , instrinsics , sigma ):
145162 # TODO: add visibility
146163 uv = b3d .camera .screen_from_world (particle_absolute_poses .pos , camera_pose , instrinsics .const )
147- uv_ = genjax .normal (uv , jnp .tile (sigma , uv .shape )) @ "sensor_coordinates"
164+ uv_ = b3d . modeling_utils .normal (uv , jnp .tile (sigma , uv .shape )) @ "sensor_coordinates"
148165 return uv_
149166
150167@genjax .gen
151168def sparse_gps_model (latent_particle_model_args , obs_model_args ):
152- # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
153- particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
154- obs = sparse_observation_model .vmap (in_axes = (0 , 0 , 0 , None , None ))(
155- particle_dynamics_summary ["absolute_particle_poses" ],
156- particle_dynamics_summary ["camera_pose" ],
157- particle_dynamics_summary ["vis_mask" ],
169+ masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
170+ _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
171+ masked_obs = sparse_observation_model .mask ().vmap (in_axes = (0 , 0 , 0 , 0 , None , None ))(
172+ masked_particle_dynamics_summary .flag ,
173+ _UNSAFE_particle_dynamics_summary ["absolute_particle_poses" ],
174+ _UNSAFE_particle_dynamics_summary ["camera_pose" ],
175+ _UNSAFE_particle_dynamics_summary ["vis_mask" ],
158176 * obs_model_args
159177 ) @ "observation"
160- return (particle_dynamics_summary , obs )
178+ return (masked_particle_dynamics_summary , masked_obs )
161179
162180
163181
@@ -166,15 +184,17 @@ def make_dense_gps_model(renderer):
166184
167185 @genjax .gen
168186 def dense_gps_model (latent_particle_model_args , dense_likelihood_args ):
169- # (b3d.camera.Intrinsics.from_array(jnp.array([1.0, 1.0, 1.0, 1.0])), 0.1)
170- particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
171- absolute_particle_poses_last_frame = particle_dynamics_summary ["absolute_particle_poses" ][- 1 ]
172- camera_pose_last_frame = particle_dynamics_summary ["camera_pose" ][- 1 ]
187+ masked_particle_dynamics_summary = latent_particle_model (* latent_particle_model_args ) @ "particle_dynamics"
188+ _UNSAFE_particle_dynamics_summary = masked_particle_dynamics_summary .value
189+
190+ last_timestep_index = jnp .sum (masked_particle_dynamics_summary .flag ) - 1
191+ absolute_particle_poses_last_frame = _UNSAFE_particle_dynamics_summary ["absolute_particle_poses" ][last_timestep_index ]
192+ camera_pose_last_frame = _UNSAFE_particle_dynamics_summary ["camera_pose" ][last_timestep_index ]
173193 absolute_particle_poses_in_camera_frame = camera_pose_last_frame .inv () @ absolute_particle_poses_last_frame
174194
175195 (meshes , likelihood_args ) = dense_likelihood_args
176196 merged_mesh = Mesh .transform_and_merge_meshes (meshes , absolute_particle_poses_in_camera_frame )
177197 image = dense_observation_model (merged_mesh , likelihood_args ) @ "observation"
178- return (particle_dynamics_summary , image )
198+ return (masked_particle_dynamics_summary , image )
179199
180200 return dense_gps_model
0 commit comments