Skip to content

Commit 59717d4

Browse files
Merge pull request #47 from probcomp/gm/task_scripts
Progress on Tracks->Segmentation and Video->[Tracks+Segmentation] Tasks, and script to visualize all tasks+solvers
2 parents 7f1fa15 + 6911ee6 commit 59717d4

25 files changed

Lines changed: 587 additions & 196 deletions

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
*.egg-info
22
*.pyc
33
*.png
4-
assets/*
4+
assets/shared_data_bucket/*
5+
assets/test_results/*
56
**/.ipynb_checkpoints

b3d/chisight/dense/model.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import b3d.utils as utils
77
import b3d.chisight.dense.differentiable_renderer as rendering
88
import rerun as rr
9+
import numpy as np
910

1011
def uniformpose_meshes_to_image_model__factory(likelihood):
1112
"""
@@ -77,15 +78,15 @@ def rr_log_uniformpose_meshes_to_image_model_trace(trace, renderer, **kwargs):
7778
"""
7879
return rr_log_meshes_to_image_model_trace(trace, renderer, **kwargs,
7980
model_args_to_densemodel_args=(
80-
lambda args: (trace.get_choices()["camera_pose"], trace.get_choices()["poses"], *args)
81+
lambda args: (trace.get_choices()["camera_pose"], trace.get_choices()("poses").c.v, *args)
8182
))
8283

8384
def rr_log_meshes_to_image_model_trace(
8485
trace, renderer,
8586
prefix="trace",
8687
timeless=False,
8788
model_args_to_densemodel_args=(lambda x: x),
88-
transform=Pose.identity()
89+
transform_Viz_Trace=Pose.identity()
8990
):
9091
"""
9192
Log to rerun a visualization of a trace from `meshes_to_image_model`.
@@ -94,11 +95,16 @@ def rr_log_meshes_to_image_model_trace(
9495
to visualize traces from other models that have the same return value as `meshes_to_image_model`.
9596
This function will call `model_args_to_densemodel_args` on the arguments of the given trace,
9697
and should produce arguments of the form accepted by `meshes_to_image_model`.
98+
99+
The argument `transform_Viz_Trace` can be used to visualize the trace at a transformed
100+
coordinate frame. `transform_Viz_Trace` is a Pose object so that for a 3D point
101+
`point_Trace` in the trace, `transform_Viz_Trace.apply(point_Trace)` is the corresponding
102+
3D point in the visualizer.
97103
"""
98104
# 2D:
99105
(observed_rgbd, metadata) = trace.get_retval()
100-
rr.log(f"/{prefix}/rgb/observed", rr.Image(observed_rgbd[:, :, :3]), timeless=timeless)
101-
rr.log(f"/{prefix}/depth/observed", rr.DepthImage(observed_rgbd[:, :, 3]), timeless=timeless)
106+
rr.log(f"/{prefix}/rgb/observed", rr.Image(np.array(observed_rgbd[:, :, :3])), timeless=timeless)
107+
rr.log(f"/{prefix}/depth/observed", rr.DepthImage(np.array(observed_rgbd[:, :, 3])), timeless=timeless)
102108

103109
# Visualization path for the average render,
104110
# if the likelihood metadata contains the output of the differentiable renderer.
@@ -107,46 +113,45 @@ def rr_log_meshes_to_image_model_trace(
107113
avg_obs = rendering.dist_params_to_average(weights, attributes, jnp.zeros(4))
108114
avg_obs_rgb_clipped = jnp.clip(avg_obs[:, :, :3], 0, 1)
109115
avg_obs_depth_clipped = jnp.clip(avg_obs[:, :, 3], 0, 1)
110-
rr.log(f"/{prefix}/rgb/average_render", rr.Image(avg_obs_rgb_clipped), timeless=timeless)
111-
rr.log(f"/{prefix}/depth/average_render", rr.DepthImage(avg_obs_depth_clipped), timeless=timeless)
116+
rr.log(f"/{prefix}/rgb/average_render", rr.Image(np.array(avg_obs_rgb_clipped)), timeless=timeless)
117+
rr.log(f"/{prefix}/depth/average_render", rr.DepthImage(np.array(avg_obs_depth_clipped)), timeless=timeless)
112118

113119
# 3D:
114-
rr.log(f"/{prefix}", rr.Transform3D(translation=transform.pos, mat3x3=transform.rot.as_matrix()), timeless=timeless)
120+
rr.log(f"/{prefix}/3D/", rr.Transform3D(translation=transform_Viz_Trace.pos, mat3x3=transform_Viz_Trace.rot.as_matrix()), timeless=timeless)
115121

116122
(X_WC, Xs_WO, vertices_O, faces, vertex_colors) = model_args_to_densemodel_args(trace.get_args())
117-
Xs_WO = trace.strip()["poses"].inner.value # TODO: do this better
118123
vertices_W = jax.vmap(lambda X_WO, v_O: X_WO.apply(v_O), in_axes=(0, 0))(Xs_WO, vertices_O)
119124
N = vertices_O.shape[0]
120125
f = jax.vmap(lambda i, f: f + i*vertices_O.shape[1], in_axes=(0, 0))(jnp.arange(N), faces)
121126
f = f.reshape(-1, 3)
122127

123-
rr.log(f"/{prefix}/mesh", rr.Mesh3D(
124-
vertex_positions=vertices_W.reshape(-1, 3),
125-
triangle_indices=f,
126-
vertex_colors=vertex_colors.reshape(-1, 3)
128+
rr.log(f"/{prefix}/3D/mesh", rr.Mesh3D(
129+
vertex_positions=np.array(vertices_W.reshape(-1, 3)),
130+
triangle_indices=np.array(f),
131+
vertex_colors=np.array(vertex_colors.reshape(-1, 3))
127132
), timeless=timeless)
128133

129-
rr.log(f"/{prefix}/camera",
134+
rr.log(f"/{prefix}/3D/camera",
130135
rr.Pinhole(
131136
focal_length=[float(renderer.fx), float(renderer.fy)],
132137
width=renderer.width,
133138
height=renderer.height,
134139
principal_point=jnp.array([renderer.cx, renderer.cy]),
135140
), timeless=timeless
136141
)
137-
rr.log(f"/{prefix}/camera", rr.Transform3D(translation=X_WC.pos, mat3x3=X_WC.rot.as_matrix()), timeless=timeless)
142+
rr.log(f"/{prefix}/3D/camera", rr.Transform3D(translation=X_WC.pos, mat3x3=X_WC.rot.as_matrix()), timeless=timeless)
138143
xyzs_C = utils.xyz_from_depth(observed_rgbd[:, :, 3], renderer.fx, renderer.fy, renderer.cx, renderer.cy)
139144
xyzs_W = X_WC.apply(xyzs_C)
140-
rr.log(f"/{prefix}/gt_pointcloud", rr.Points3D(
141-
positions=xyzs_W.reshape(-1,3),
142-
colors=observed_rgbd[:, :, :3].reshape(-1,3),
143-
radii = 0.001*jnp.ones(xyzs_W.reshape(-1,3).shape[0])),
145+
rr.log(f"/{prefix}/3D/gt_pointcloud", rr.Points3D(
146+
positions=np.array(xyzs_W.reshape(-1,3)),
147+
colors=np.array(observed_rgbd[:, :, :3].reshape(-1,3)),
148+
radii = 0.001*np.ones(xyzs_W.reshape(-1,3).shape[0])),
144149
timeless=timeless
145150
)
146151

147152
patch_centers_W = jax.vmap(lambda X_WO: X_WO.pos)(Xs_WO)
148153
rr.log(
149-
f"/{prefix}/patch_centers_W",
150-
rr.Points3D(positions=patch_centers_W, colors=jnp.array([0., 0., 1.]), radii=0.003),
154+
f"/{prefix}/3D/patch_centers",
155+
rr.Points3D(positions=np.array(patch_centers_W), colors=np.array([0., 0., 1.]), radii=0.003),
151156
timeless=timeless
152157
)

b3d/chisight/dense/patch_tracking.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from b3d import Pose
44
import b3d
55
import genjax
6+
from genjax import ChoiceMapBuilder as C
67
import b3d.chisight.dense.differentiable_renderer as r
78
import b3d.chisight.dense.model as m
89
import b3d.chisight.dense.likelihoods as likelihoods
@@ -78,15 +79,10 @@ def get_adam_optimization_patch_tracker(model, patch_vertices_P, patch_faces, pa
7879
def importance_from_pos_quat(positions, quaternions, observed_rgbd):
7980
key = jax.random.PRNGKey(0) # This value shouldn't matter, in the current model version.
8081
poses = jax.vmap(lambda pos, quat: Pose.from_vec(jnp.concatenate([pos, quat])), in_axes=(0, 0))(positions, quaternions)
81-
trace, weight = model.importance(
82-
key,
83-
genjax.ChoiceMap.d({
84-
"poses": genjax.ChoiceMap.idx(jnp.arange(positions.shape[0]), poses),
85-
"camera_pose": X_WC,
86-
"observed_image": {"observed_image": {"obs": observed_rgbd}}
87-
}),
88-
(patch_vertices_P, patch_faces, patch_vertex_colors)
89-
)
82+
cm = jax.vmap(lambda i: C["poses", i].set(poses[i]))(jnp.arange(poses.shape[0]))
83+
cm = cm.merge(C["camera_pose"].set(b3d.Pose.identity()))
84+
cm = cm.merge(C["observed_image", "observed_image", "obs"].set(observed_rgbd))
85+
trace, weight = model.importance(key, cm, (patch_vertices_P, patch_faces, patch_vertex_colors))
9086
return trace, weight
9187

9288
def weight_from_pos_quat(pos, quat, observed_rgbd):

tests/common/solver.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,11 @@ def visualize_solver_state(self):
2222
Visualize any information recorded by the solver during the last call to `solve`.
2323
This may log data to rerun, produce pyplots, etc.
2424
"""
25-
pass
25+
pass
26+
27+
@property
28+
def name(self) -> str:
29+
"""
30+
Returns the name of the solver.
31+
"""
32+
return self.__class__.__name__

tests/common/task.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@ class Task:
77
- `score`
88
99
Optionally, task developers may also implement the following methods:
10+
- @property `name`
1011
- `assert_passing`
1112
- `visualize_task`
1213
- `visualize_solution`
1314
1415
The `run_and_score` and `run_tests` methods are automatically implemented from the above.
16+
17+
It is recommended that `Task` constrution should be very fast, and that
18+
data-loading or other slow operations should be performed lazily
19+
at the first call to `get_task_specification` or `visualize_task`.
1520
"""
1621

1722
### CORE TASK INTERFACE ###
@@ -38,6 +43,13 @@ def score(self, solution, **kwargs) -> "Any":
3843
"""
3944
raise NotImplementedError()
4045

46+
@property
47+
def name(self) -> str:
48+
"""
49+
Returns the name of the task.
50+
"""
51+
return self.__class__.__name__
52+
4153
def assert_passing(self, scores, **kwargs) -> None:
4254
"""
4355
Takes the output of `score` and makes assertions about the scores,
@@ -57,7 +69,7 @@ def assert_passing(self, scores, **kwargs) -> None:
5769
according to the default tolerances).
5870
"""
5971
raise NotImplementedError()
60-
72+
6173
def visualize_task(self):
6274
"""
6375
Visualize the task (but not the solution).
@@ -93,6 +105,7 @@ def run_and_score(self, solver, viz=False, **kwargs) -> dict:
93105
metrics = self.score(task_output, **kwargs)
94106
if viz:
95107
self.visualize_solution(task_output, metrics)
108+
solver.visualize_solver_state(task_spec)
96109
return metrics
97110

98111
def run_tests(self, solver, viz=False, **kwargs) -> None:
@@ -111,3 +124,18 @@ def run_tests(self, solver, viz=False, **kwargs) -> None:
111124
# Score and assess whether passing.
112125
metrics = self.run_and_score(solver, viz=viz, **kwargs)
113126
self.assert_passing(metrics, **kwargs)
127+
128+
def run_solver_and_make_all_visualizations(self, solver):
129+
"""
130+
Run the solver and make all visualizations.
131+
"""
132+
task_spec = self.get_task_specification()
133+
self.visualize_task()
134+
task_output = solver.solve(task_spec)
135+
metrics = self.score(task_output)
136+
solver.visualize_solver_state(task_spec)
137+
self.visualize_solution(task_output, metrics)
138+
return metrics
139+
140+
def __repr__(self):
141+
return f"{self.name}()"

tests/dense_model_unit_tests/triangle_depth_posterior/test_triangle_depth_posterior.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
for triangle_color in [jnp.array([1., 0., 0.]), jnp.array([0., 1., 0.])]
1111
]
1212

13-
@pytest.mark.parametrize("task_spec", task_specs[:3])
13+
# Only run one test for now, to prevent issues due to the current
14+
# memory leak in the renderer.
15+
@pytest.mark.parametrize("task_spec", task_specs[:1])
1416
def test(task_spec):
1517
task = TrianglePosteriorGridApproximationTask.default_scene_using_colors(*task_spec)
1618
task.run_tests(

tests/dense_model_unit_tests/triangle_depth_posterior/visualize.ipynb

Lines changed: 155 additions & 0 deletions
Large diffs are not rendered by default.

tests/sama4d/data_curation.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,19 @@ def get_cheezitbox_scene_loader(n_frames=30):
5959

6060
# Scene manually constructed in Python: rotating cheezit box
6161
def ftd_from_rotating_cheezit_box(n_frames=30):
62-
(r, centers_2D_frame_0, centers_3D_W_over_time, Xs_WC, observed_rgbds) = load_rotating_cheezit_box_data(n_frames)
62+
(r, centers_2D_frame_0, centers_3D_W_over_time, poses_WC, observed_rgbds) = load_rotating_cheezit_box_data(n_frames)
6363
return b3d.io.FeatureTrackData(
6464
observed_keypoints_positions=jax.vmap(lambda positions_3D_W, X_WC: b3d.xyz_to_pixel_coordinates(
6565
X_WC.inv().apply(positions_3D_W), r.fx, r.fy, r.cx, r.cy
66-
), in_axes=(0, 0))(centers_3D_W_over_time, Xs_WC),
66+
), in_axes=(0, 0))(centers_3D_W_over_time, poses_WC),
6767
keypoint_visibility=jnp.ones((n_frames, centers_2D_frame_0.shape[0]), dtype=bool),
6868
camera_intrinsics=r.get_intrinsics_object().as_array(),
6969
rgbd_images=observed_rgbds,
7070
latent_keypoint_positions=centers_3D_W_over_time,
71-
camera_position=Xs_WC.pos,
72-
camera_quaternion=Xs_WC.xyzw
71+
camera_position=poses_WC.pos,
72+
camera_quaternion=poses_WC.xyzw,
73+
# Every point is assigned to one object (the cheez-it box)
74+
object_assignments=jnp.zeros(centers_2D_frame_0.shape[0], dtype=int)
7375
)
7476

7577
def load_rotating_cheezit_box_data(n_frames=30):
@@ -119,9 +121,9 @@ def load_rotating_cheezit_box_data(n_frames=30):
119121
lambda X_W_Bt: X_W_Bt.apply(centers_3D_B0)
120122
)(box_poses_W)
121123

122-
Xs_WC = jax.vmap(lambda x: X_WC)(jnp.arange(n_frames))
124+
poses_WC = jax.vmap(lambda x: X_WC)(jnp.arange(n_frames))
123125

124-
return (renderer, centers_2D_frame_0, centers_3D_W_over_time, Xs_WC, observed_rgbds)
126+
return (renderer, centers_2D_frame_0, centers_3D_W_over_time, poses_WC, observed_rgbds)
125127

126128
### Utils ###
127129

File renamed without changes.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import jax.numpy as jnp
2+
from tests.common.solver import Solver
3+
4+
class DummyTracksToSegmentationSolver(Solver):
5+
def solve(self, task_spec):
6+
# assign every keypoint to the same object
7+
# (called object 0)
8+
return jnp.zeros(
9+
task_spec["keypoint_tracks_2D"].shape[1],
10+
dtype=int
11+
)
12+
13+
def visualize_solver_state(self, task_spec):
14+
pass

0 commit comments

Comments
 (0)