Commit 3dc1321

Merge pull request #19 from GradientSpaces/version_1.0
Update to version 1.0:
- Improve training speed (9-12% faster) by torch.compile and training stability by timestep clipping.
- Fix bugs in configs, RK2 sampler, and validation.
2 parents 167bfed + 391d1b3 commit 3dc1321

13 files changed

Lines changed: 281 additions & 273 deletions

README.md

Lines changed: 10 additions & 4 deletions
@@ -15,9 +15,15 @@
 
 
 ## 🔔 News
-- [July 15, 2025] Improve training stability and fix bugs.
-- [July 9, 2025] Released training codes.
-- [July 1, 2025] Released model checkpoints and inference codes.
+- [July 25, 2025] **Version 1.0**: We strongly recommend updating to this version, which includes:
+  - Improved model speed (9-12% faster) and training stability.
+  - Fixed bugs in configs, RK2 sampler, and validation.
+  - Simplified point cloud packing and shaping.
+  - Checkpoints are compatible with the previous version.
+
+- [July 9, 2025] **Version 0.1**: Release training codes.
+
+- [July 1, 2025] Initial release of the model checkpoints and inference codes.
 
 ## Overview
 
@@ -74,7 +80,7 @@ This saves images of the input (unposed) parts and multiple generations for poss
 
 - **Renderer**: We use [Mitsuba](https://mitsuba.readthedocs.io/en/latest/) for high quality ray-traced rendering, as shown above. For a faster rendering, please switch to [PyTorch3D PointsRasterizer](https://pytorch3d.readthedocs.io/en/latest/modules/renderer/points/rasterizer.html#pytorch3d.renderer.points.rasterizer.PointsRasterizer) by adding `visualizer.renderer=pytorch3d`. To disable rendering, use `visualizer.renderer=none`. More rendering options are available in [config/visualizer](config/visualizer/flow.yaml).
 
-- **Sampler**: We support Euler, RK2 (default), and RK4 samplers for inference, set `model.inference_sampler={euler, rk2, rk4}` accordingly.
+- **Sampler**: We support Euler (default), RK2, and RK4 samplers for inference, set `model.inference_sampler={euler, rk2, rk4}` accordingly.
 
 **Overlap Prediction:** To visualize the overlap probabilities predicted by the encoder, please run:
 
config/RPF_base_main_10k.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
+# Training Rectified Point Flow
+
+defaults:
+  - model: rectified_point_flow
+  - data: ikea_partnet_everyday_twobytwo_modelnet_tudl
+  - trainer: main
+  - loggers: wandb
+  - _self_
+
+# Random seed for reproducibility
+seed: 42
+
+# Data root
+data_root: "../dataset"
+data:
+  num_points_to_sample: 10000
+
+# Experiment name and log directory
+experiment_name: RPF_base
+log_dir: ./output/${experiment_name}
+ckpt_path: ${log_dir}/last.ckpt
+hydra:
+  run:
+    dir: ${log_dir}
+
+# Model settings
+model:
+  encoder_ckpt: null
+  flow_model_ckpt: null
+
+  flow_model:
+    # For 10k points, we replace QK norm with softcapping for speed.
+    attn_dtype: "bfloat16"
+    softcap: 50.0
+    qk_norm: False
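
The `softcap: 50.0` setting in this config stands in for QK normalization. As a rough illustration, the sketch below assumes the common tanh formulation of logit softcapping, `softcap * tanh(logit / softcap)`; the attention kernel used in the repository may implement it differently. `softcap_logits` and `softmax` are hypothetical helpers written for this example, not functions from the codebase.

```python
import math

def softcap_logits(logits, softcap=50.0):
    """Smoothly bound attention logits to (-softcap, softcap) via tanh (assumed form)."""
    return [softcap * math.tanh(x / softcap) for x in logits]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

raw = [1.0, 10.0, 400.0]        # made-up logits with one large outlier
capped = softcap_logits(raw)    # small logits barely change, the outlier is squashed
weights = softmax(capped)       # attention weights computed from the capped logits
```

Small logits pass through almost unchanged while large ones saturate just below the cap, which bounds the logits without a separate normalization step.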

config/model/rectified_point_flow.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ lr_scheduler:
   gamma: 0.5
 
 timestep_sampling: "u_shaped"
-inference_sampler: "rk2"
-inference_sampling_steps: 20
+inference_sampler: "euler"
+inference_sampling_steps: 50
 n_generations: 1
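
This hunk switches the default from RK2 with 20 steps (about 40 velocity evaluations if RK2 uses two per step) to Euler with 50 steps (50 evaluations). To show how the three sampler families differ, here is a minimal sketch of fixed-step integrators on a toy scalar ODE. `integrate` and the toy `velocity` are hypothetical, the RK2 variant shown is the midpoint rule, and the repository's samplers may use a different RK2 scheme or time direction.

```python
import math

def integrate(velocity, x0, n_steps, method="euler"):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with a fixed step size."""
    x, dt = x0, 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        if method == "euler":      # 1 velocity evaluation per step
            x = x + dt * velocity(x, t)
        elif method == "rk2":      # midpoint rule, 2 evaluations per step
            k1 = velocity(x, t)
            x = x + dt * velocity(x + 0.5 * dt * k1, t + 0.5 * dt)
        elif method == "rk4":      # classic Runge-Kutta, 4 evaluations per step
            k1 = velocity(x, t)
            k2 = velocity(x + 0.5 * dt * k1, t + 0.5 * dt)
            k3 = velocity(x + 0.5 * dt * k2, t + 0.5 * dt)
            k4 = velocity(x + dt * k3, t + dt)
            x = x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0
        else:
            raise ValueError(f"unknown sampler: {method}")
    return x

# Toy ODE dx/dt = -x with exact solution x(1) = exp(-1).
velocity = lambda x, t: -x
exact = math.exp(-1.0)
errors = {
    ("euler", 50): abs(integrate(velocity, 1.0, 50, "euler") - exact),
    ("rk2", 20): abs(integrate(velocity, 1.0, 20, "rk2") - exact),
    ("rk4", 20): abs(integrate(velocity, 1.0, 20, "rk4") - exact),
}
```

On this toy problem the higher-order methods are more accurate per step. That says nothing about which sampler works best for the actual flow model, where the learned velocity field and the sampler's correctness both matter.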

rectified_point_flow/data/dataset.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ def __getitem__(self, index):
             - rotations (P, 3, 3) float32: Rotation matrices.
             - translations (P, 3) float32: Translation vectors.
             - points_per_part (P) int64: Number of points per part.
-            - scale (P) float32: Scale of the point clouds.
-            - anchor_part (P) bool: Whether the part is an anchor part.
+            - scales (1, ) float32: Scale of the point clouds.
+            - anchor_parts (P) bool: Boolean array indicating anchor parts.
+            - anchor_indices (N, ) bool: Boolean array indicating anchor points.
             - init_rotation (3, 3) float32: Initial rotation matrix of the pointclouds_gt, used for recovering the original data.
 
         Note:
@@ -323,7 +324,6 @@ def _proc_part(i):
         pts_per_part = pad_data(counts, self.max_parts)
         rots = pad_data(np.stack(rots), self.max_parts)
         trans = pad_data(np.stack(trans), self.max_parts)
-        scale = pad_data(np.array([scale] * n_parts), self.max_parts)
 
         # Use the largest part as the anchor part
         anchor = np.zeros(self.max_parts, bool)
@@ -347,6 +347,13 @@ def _proc_part(i):
             rots[extra_idx] = np.eye(3)
             trans[extra_idx] = np.zeros(3)
 
+        # Broadcast anchor part to points
+        anchor_indices = np.zeros(self.num_points_to_sample, bool)
+        for i in range(n_parts):
+            if anchor[i]:
+                st, ed = offsets[i], offsets[i + 1]
+                anchor_indices[st:ed] = True
+
         results = {}
         for key in ["index", "name", "overlap_threshold"]:
             results[key] = data[key]
@@ -360,8 +367,9 @@ def _proc_part(i):
         results["rotations"] = rots.astype(np.float32)
         results["translations"] = trans.astype(np.float32)
         results["points_per_part"] = pts_per_part.astype(np.int64)
-        results["scale"] = scale.astype(np.float32)
-        results["anchor_part"] = anchor.astype(bool)
+        results["scales"] = np.array(scale, dtype=np.float32)
+        results["anchor_parts"] = anchor.astype(bool)
+        results["anchor_indices"] = anchor_indices.astype(bool)
         results["init_rotation"] = init_rot.astype(np.float32)
 
         return results
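
The new `anchor_indices` field expands the per-part anchor flag into a per-point mask. The diff does not show how `offsets` is built, so the sketch below assumes it is the cumulative sum of the per-part point counts. It uses plain Python lists rather than NumPy arrays to stay dependency-free, and `anchor_point_mask` is a hypothetical name.

```python
from itertools import accumulate

def anchor_point_mask(points_per_part, anchor_parts):
    """Expand a per-part anchor flag into a per-point boolean mask."""
    # offsets[i]:offsets[i + 1] is the slice of points belonging to part i
    # (assumed layout: parts are packed back to back).
    offsets = [0, *accumulate(points_per_part)]
    mask = [False] * offsets[-1]
    for i, is_anchor in enumerate(anchor_parts):
        if is_anchor:
            st, ed = offsets[i], offsets[i + 1]
            mask[st:ed] = [True] * (ed - st)
    return mask

# Three parts with 2, 3 and 1 points; the second part is the anchor.
mask = anchor_point_mask([2, 3, 1], [False, True, False])
```

Precomputing the mask in the dataset means downstream code can index points directly instead of repeating the part-to-point expansion per batch.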

rectified_point_flow/eval/evaluator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ def _compute_metrics(
         pts = data["pointclouds"]  # (B, N, 3)
         pts_gt = data["pointclouds_gt"]  # (B, N, 3)
         points_per_part = data["points_per_part"]  # (B, P)
-        anchor_part = data["anchor_part"]  # (B, P)
-        scale = data["scale"][:, 0]  # (B,)
+        anchor_parts = data["anchor_parts"]  # (B, P)
+        scales = data["scales"]  # (B,)
 
-        # Rescale to original scale
+        # Rescale to original scales
         B, _, _ = pts_gt.shape
         pointclouds_pred = pointclouds_pred.view(B, -1, 3)
-        pts_gt_rescaled = pts_gt * scale.view(B, 1, 1)
-        pts_pred_rescaled = pointclouds_pred * scale.view(B, 1, 1)
+        pts_gt_rescaled = pts_gt * scales.view(B, 1, 1)
+        pts_pred_rescaled = pointclouds_pred * scales.view(B, 1, 1)
 
         object_cd = compute_object_cd(pts_gt_rescaled, pts_pred_rescaled)
         part_acc, matched_parts = compute_part_acc(pts_gt_rescaled, pts_pred_rescaled, points_per_part)
@@ -43,7 +43,7 @@ def _compute_metrics(
 
         if rotations_pred is not None and translations_pred is not None:
             rot_errors, trans_errors = compute_transform_errors(
-                pts, pts_gt, rotations_pred, translations_pred, points_per_part, anchor_part, matched_parts, scale,
+                pts, pts_gt, rotations_pred, translations_pred, points_per_part, anchor_parts, matched_parts, scales,
             )
             rot_recalls = self._recall_at_thresholds(rot_errors, [5, 10])
             trans_recalls = self._recall_at_thresholds(trans_errors, [0.01, 0.05])
@@ -84,7 +84,7 @@ def _save_single_result(
             "dataset": dataset_name,
             "num_parts": int(data["num_parts"][idx]),
             "generation_idx": generation_idx,
-            "scale": float(data["scale"][idx, 0]),
+            "scales": float(data["scales"][idx]),
         }
         entry.update({k: float(v[idx]) for k, v in metrics.items()})
 
@@ -107,7 +107,7 @@ def run(
         Args:
             data: Input data dictionary, containing:
                 pointclouds_gt (B, N, 3): Ground truth point clouds.
-                scale (B,): Scale factors.
+                scales (B,): Scale factors.
                 points_per_part (B, P): Points per part.
                 name (B,): Object names.
                 dataset_name (B,): Dataset names.
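
The evaluator reports recall at fixed thresholds for rotation and translation errors via `_recall_at_thresholds`, whose body is not part of this diff. A minimal sketch of what such a metric typically computes is given below. The standalone function, the `<=` comparison, and the reading of `[5, 10]` as degrees are all assumptions, not the repository's implementation.

```python
def recall_at_thresholds(errors, thresholds):
    """Fraction of samples whose error is at or below each threshold (assumed definition)."""
    n = len(errors)
    return {thr: sum(e <= thr for e in errors) / n for thr in thresholds}

rot_errors = [2.0, 7.5, 12.0, 4.0]   # made-up rotation errors, assumed to be in degrees
recalls = recall_at_thresholds(rot_errors, [5, 10])
```

With these made-up errors, two of four samples fall within 5 and three of four within 10.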

rectified_point_flow/flow_model/embedding.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -117,33 +117,42 @@ def __init__(self, in_dim: int, embed_dim: int, multires: int = 10):
 
     def forward(
         self,
-        x: torch.Tensor,  # (n_points, 3)
-        latent: dict,  # PointTransformer's `Point` instance
-        scale: torch.Tensor,  # (n_valid_parts, )
+        x: torch.Tensor,
+        latent: dict,
+        scales: torch.Tensor,
     ) -> torch.Tensor:
         """Generate PointCloudEmbedding from the input.
 
         Args:
-            x: Input coordinates tensor of shape (n_points, 3).
-            latent: Dictionary containing point cloud features and metadata.
-            scale: Scale factors of shape (n_valid_parts, 1).
+            x (B, N, 3): Noise point coordinates at timestep t.
+            latent: PointTransformer's Point instance of conditional point cloud:
+                - "coord" (n_points, 3): Point coordinates
+                - "normal" (n_points, 3): Point normals
+                - "feat" (n_points, in_dim): Point features
+            scales (B, ): Scale factor for the point cloud.
 
         Returns:
-            Shape embeddings of shape (n_points, embed_dim).
+            Shape embeddings of shape (B, N, dim).
         """
+        B, N, _ = x.shape
+
         # Coordinate of noise PCs
-        x_pos_emb = self.noise_embedding.embed(x)  # (n_points, emb_dim)
+        x_pos_emb = self.noise_embedding.embed(x)  # (B, N, dim)
 
         # Coordinate of condition PCs
-        c_pos_emb = self.coord_embedding.embed(latent["coord"])  # (n_points, emb_dim)
-
+        coord = latent["coord"].view(B, N, 3)
+        c_pos_emb = self.coord_embedding.embed(coord)  # (B, N, dim)
+
         # Normal of condition PCs
-        normal_emb = self.normal_embedding.embed(latent["normal"])  # (n_points, emb_dim)
+        normal = latent["normal"].view(B, N, 3)
+        normal_emb = self.normal_embedding.embed(normal)  # (B, N, dim)
 
         # Scale of condition PCs
-        scale_emb = self.scale_embedding.embed(scale.unsqueeze(-1))  # (n_valid_parts, emb_dim)
-        scale_emb = scale_emb[latent["batch"]]  # (n_points, emb_dim)
+        scale_emb = self.scale_embedding.embed(scales.unsqueeze(-1))  # (B, dim)
+        scale_emb = scale_emb.unsqueeze(1).expand(B, N, -1)  # (B, N, dim)
 
         # Concatenate with point features
-        x = torch.cat([latent["feat"], c_pos_emb, x_pos_emb, normal_emb, scale_emb], dim=-1)
+        feat = latent["feat"].view(B, N, -1)  # (B, N, in_dim)
+        x = torch.cat([feat, c_pos_emb, x_pos_emb, normal_emb, scale_emb], dim=-1)
+
         return self.emb_proj(x)
