Skip to content

Commit b7dacec

Browse files
committed
ENH: Update from copilot
1 parent e6f602f commit b7dacec

6 files changed

Lines changed: 149 additions & 25 deletions

File tree

experiments/Heart-GatedCT_To_USD/1-register_images.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@
103103
"metadata": {},
104104
"outputs": [],
105105
"source": [
106-
"for i in range(0, 21, 1): # Process every 4th slice to save time testing\n",
106+
"for i in range(0, 21, 4): # Process every 4th slice to save time testing\n",
107107
" print(f\"Processing slice {i:03d}\")\n",
108108
" moving_image = itk.imread(os.path.join(data_dir, f\"slice_{i:03d}.mha\"))\n",
109109
" result = seg.segment(moving_image, contrast_enhanced_study=True)\n",
@@ -239,4 +239,4 @@
239239
},
240240
"nbformat": 4,
241241
"nbformat_minor": 5
242-
}
242+
}

src/physiomotion4d/register_models_pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(
8181
pca_template_model_point_subsample: int = 4,
8282
pre_pca_transform: Optional[itk.Transform] = None,
8383
fixed_distance_map: Optional[itk.Image] = None,
84-
fixed_model: Optional[pv.UnstructuredGrid] = None,
84+
fixed_model: Optional[pv.UnstructuredGrid | pv.PolyData] = None,
8585
reference_image: Optional[itk.Image] = None,
8686
log_level: int | str = logging.INFO,
8787
):
@@ -186,7 +186,7 @@ def from_json(
186186
pca_template_model_point_subsample: int = 4,
187187
pre_pca_transform: Optional[itk.Transform] = None,
188188
fixed_distance_map: Optional[itk.Image] = None,
189-
fixed_model: Optional[pv.UnstructuredGrid] = None,
189+
fixed_model: Optional[pv.UnstructuredGrid | pv.PolyData] = None,
190190
reference_image: Optional[itk.Image] = None,
191191
log_level: int | str = logging.INFO,
192192
) -> Self:

src/physiomotion4d/usd_tools.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,21 @@ def apply_colormap_from_primvar(
909909

910910
# Value range: use provided intensity_range or compute from data
911911
if intensity_range is not None:
912-
vmin, vmax = intensity_range
912+
try:
913+
vmin, vmax = float(intensity_range[0]), float(intensity_range[1])
914+
except (TypeError, IndexError) as e:
915+
raise ValueError(
916+
"intensity_range must be a sequence of two floats (vmin, vmax)"
917+
) from e
918+
if not (np.isfinite(vmin) and np.isfinite(vmax)):
919+
raise ValueError(
920+
f"intensity_range values must be finite; got ({vmin}, {vmax})"
921+
)
922+
if vmin >= vmax:
923+
vmin, vmax = vmax, vmin
924+
self.log_info(
925+
f"intensity_range was (vmax, vmin); swapped to {vmin:.6g} to {vmax:.6g}"
926+
)
913927
self.log_info(f"Using specified intensity range: {vmin:.6g} to {vmax:.6g}")
914928
else:
915929
all_values = np.concatenate([s for _, s in scalar_samples])
@@ -952,7 +966,7 @@ def apply_colormap_from_primvar(
952966
normalized = np.full_like(scalar, 0.5)
953967

954968
if use_sigmoid_scale:
955-
normalized = 1 / (1 + np.exp(-4 * normalized))
969+
normalized = 1 / (1 + np.exp(-4 * (normalized - 0.5)))
956970

957971
normalized = np.clip(normalized, 0.0, 1.0)
958972

@@ -1046,6 +1060,7 @@ def set_solid_display_color(
10461060
color_array = Vt.Vec3fArray([vec] * n_points)
10471061
display_color_pv.Set(color_array)
10481062
else:
1063+
default_point_count: int | None = None
10491064
for tc in time_codes:
10501065
# Normalize to a Usd.TimeCode
10511066
usd_tc = tc if isinstance(tc, Usd.TimeCode) else Usd.TimeCode(tc)
@@ -1059,12 +1074,20 @@ def set_solid_display_color(
10591074
n_points = len(pts) if pts is not None else 0
10601075
if n_points == 0:
10611076
continue
1077+
if default_point_count is None:
1078+
default_point_count = n_points
10621079
color_array = Vt.Vec3fArray([vec] * n_points)
10631080
if usd_tc.IsDefault():
10641081
display_color_pv.Set(color_array)
10651082
else:
10661083
display_color_pv.Set(color_array, usd_tc)
10671084

1085+
# Author a default (time-independent) value so consumers that query the
1086+
# default when not time-scrubbing still see the solid color.
1087+
if default_point_count is not None:
1088+
default_color_array = Vt.Vec3fArray([vec] * default_point_count)
1089+
display_color_pv.Set(default_color_array)
1090+
10681091
if bind_vertex_color_material:
10691092
self._ensure_vertex_color_material(stage, mesh_prim)
10701093
if stage_path:

src/physiomotion4d/vtk_to_usd/converter.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,83 @@ def convert_file(
112112

113113
return stage
114114

115+
def convert_files_static(
116+
self,
117+
vtk_files: Sequence[str | Path],
118+
output_usd: str | Path,
119+
mesh_name: str = "Mesh",
120+
material: Optional[MaterialData] = None,
121+
extract_surface: bool = True,
122+
) -> Usd.Stage:
123+
"""Convert multiple VTK files into one static USD stage (no time samples).
124+
125+
All meshes from all files are added to the scene at default time. Use this
126+
when multiple files are provided but filenames do not match a time-series
127+
pattern, so they should be combined as a single static scene rather than
128+
time steps.
129+
130+
Args:
131+
vtk_files: List of VTK file paths
132+
output_usd: Path to output USD file
133+
mesh_name: Base name for meshes (each file/part gets a unique name)
134+
material: Optional material data. If None, uses default.
135+
extract_surface: For .vtu files, whether to extract surface
136+
137+
Returns:
138+
Usd.Stage: Created USD stage
139+
"""
140+
if len(vtk_files) == 0:
141+
raise ValueError("Empty file list")
142+
143+
logger.info(
144+
"Converting %d files to static USD (no time samples): %s",
145+
len(vtk_files),
146+
output_usd,
147+
)
148+
149+
# Create USD stage once (no time range)
150+
self._create_stage(output_usd)
151+
stage = self.stage
152+
mesh_converter = self.mesh_converter
153+
material_mgr = self.material_mgr
154+
assert stage is not None
155+
assert mesh_converter is not None
156+
assert material_mgr is not None
157+
158+
if material is not None:
159+
material_mgr.get_or_create_material(material)
160+
161+
for file_idx, vtk_file in enumerate(vtk_files):
162+
mesh_data = read_vtk_file(vtk_file, extract_surface=extract_surface)
163+
if material is not None:
164+
mesh_data.material_id = material.name
165+
166+
# Unique base per file to avoid prim path collisions
167+
file_base = f"{mesh_name}_{file_idx}"
168+
169+
if self.settings.separate_objects_by_connectivity:
170+
parts = split_mesh_data_by_connectivity(mesh_data, mesh_name=file_base)
171+
for _idx, (part_data, base_name) in enumerate(parts):
172+
mesh_path = f"/World/Meshes/{base_name}"
173+
self._ensure_parent_path(mesh_path)
174+
mesh_converter.create_mesh(part_data, mesh_path, bind_material=True)
175+
elif self.settings.separate_objects_by_cell_type:
176+
parts = split_mesh_data_by_cell_type(mesh_data, mesh_name=file_base)
177+
for idx, (part_data, base_name) in enumerate(parts):
178+
prim_name = f"{base_name}_{idx}"
179+
mesh_path = f"/World/Meshes/{prim_name}"
180+
self._ensure_parent_path(mesh_path)
181+
mesh_converter.create_mesh(part_data, mesh_path, bind_material=True)
182+
else:
183+
mesh_path = f"/World/Meshes/{file_base}"
184+
self._ensure_parent_path(mesh_path)
185+
mesh_converter.create_mesh(mesh_data, mesh_path, bind_material=True)
186+
187+
stage.Save()
188+
logger.info(f"Saved USD file: {output_usd}")
189+
190+
return stage
191+
115192
def convert_sequence(
116193
self,
117194
vtk_files: Sequence[str | Path],

src/physiomotion4d/vtk_to_usd/mesh_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,17 @@ def split_mesh_data_by_cell_type(
3434
) -> list[tuple[MeshData, str]]:
3535
"""Split MeshData into one mesh per distinct face vertex count (cell type).
3636
37-
Each part is named by cell type (e.g. Triangle, Quad, Hexahedron). The caller
38-
should append a unique number to form final prim names (e.g. Triangle_0, Quad_0).
37+
Each part is named as mesh_name plus the cell type (e.g. MeshName_Triangle,
38+
MeshName_Quad). The caller should append a unique number to form final prim
39+
names (e.g. MeshName_Triangle_0, MeshName_Quad_0).
3940
4041
Args:
4142
mesh_data: Single mesh that may contain mixed cell types.
43+
mesh_name: Name of the source mesh; used as prefix in returned base_name.
4244
4345
Returns:
4446
List of (MeshData, base_name) for each cell type present. base_name is
45-
the cell type name (e.g. "Triangle", "Quad").
47+
mesh_name + "_" + cell type name (e.g. "MeshName_Triangle", "MeshName_Quad").
4648
"""
4749
counts = np.asarray(mesh_data.face_vertex_counts, dtype=np.int32)
4850
indices = np.asarray(mesh_data.face_vertex_indices, dtype=np.int32)
@@ -282,14 +284,15 @@ def split_mesh_data_by_connectivity(
282284
"""Split MeshData into one mesh per connected component.
283285
284286
A connected component is a maximal set of cells that share vertices (directly
285-
or transitively). Components are named object1, object2, etc.
287+
or transitively). Components are named mesh_name_object1, mesh_name_object2, etc.
286288
287289
Args:
288290
mesh_data: Single mesh that may contain multiple disconnected parts.
291+
mesh_name: Name of the source mesh; used as prefix in returned base_name.
289292
290293
Returns:
291294
List of (MeshData, base_name) for each component. base_name is
292-
"object1", "object2", ...
295+
mesh_name + "_objectN" (e.g. "MeshName_object1", "MeshName_object2", ...).
293296
"""
294297
counts = np.asarray(mesh_data.face_vertex_counts, dtype=np.int32)
295298
indices = np.asarray(mesh_data.face_vertex_indices, dtype=np.int32)

src/physiomotion4d/workflow_convert_vtk_to_usd.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,26 +27,30 @@
2727
def discover_time_series(
2828
paths: list[Path],
2929
pattern: str = r"\.t(\d+)\.(vtk|vtp|vtu)$",
30-
) -> list[tuple[int, Path]]:
30+
) -> tuple[list[tuple[int, Path]], bool]:
3131
"""Discover and sort time-series VTK files by extracted time index.
3232
3333
Args:
3434
paths: List of paths to VTK files
3535
pattern: Regex with one group for time step number (default matches .t123.vtk)
3636
3737
Returns:
38-
Sorted list of (time_step, path) tuples. If no match, returns [(0, p) for p in paths].
38+
(time_series, pattern_matched): Sorted list of (time_step, path) tuples, and
39+
a flag True if at least one path matched the pattern. If no path matches,
40+
time_series is [(0, p) for p in paths] and pattern_matched is False.
3941
"""
4042
time_series: list[tuple[int, Path]] = []
4143
regex = re.compile(pattern, re.IGNORECASE)
44+
pattern_matched = False
4245
for p in paths:
4346
match = regex.search(p.name)
4447
if match:
4548
time_series.append((int(match.group(1)), Path(p)))
49+
pattern_matched = True
4650
else:
4751
time_series.append((0, Path(p)))
4852
time_series.sort(key=lambda x: (x[0], str(x[1])))
49-
return time_series
53+
return time_series, pattern_matched
5054

5155

5256
AppearanceKind = Literal["solid", "anatomy", "colormap"]
@@ -141,17 +145,24 @@ def run(self) -> str:
141145
raise ValueError("vtk_files must not be empty")
142146

143147
# Discover time series
144-
time_series = discover_time_series(
148+
time_series, pattern_matched = discover_time_series(
145149
self.vtk_files, pattern=self.time_series_pattern
146150
)
147151
time_steps = [t for t, _ in time_series]
148152
time_codes = [float(t) for t in time_steps]
149153
paths_ordered = [p for _, p in time_series]
150154
n_frames = len(paths_ordered)
151155

156+
# Multiple files but no pattern match: treat as static scene (all at time 0, no time samples)
157+
is_static_merge = n_frames > 1 and not pattern_matched
158+
152159
self.log_info("Input: %d file(s), time steps: %s", n_frames, time_steps[:5])
153160
if n_frames > 5:
154161
self.log_info(" ... and %d more", n_frames - 5)
162+
if is_static_merge:
163+
self.log_info(
164+
"Filenames do not match time-series pattern; outputting static scene (no time samples)"
165+
)
155166
self.log_info("Output: %s", self.output_usd)
156167

157168
settings = ConversionSettings(
@@ -163,7 +174,7 @@ def run(self) -> str:
163174
separate_objects_by_cell_type=self.separate_by_cell_type,
164175
up_axis=self.up_axis,
165176
times_per_second=self.times_per_second,
166-
use_time_samples=True,
177+
use_time_samples=not is_static_merge,
167178
)
168179

169180
converter = VTKToUSDConverter(settings)
@@ -181,13 +192,22 @@ def run(self) -> str:
181192
material=default_material,
182193
extract_surface=self.extract_surface,
183194
)
195+
elif is_static_merge:
196+
stage = converter.convert_files_static(
197+
paths_ordered,
198+
self.output_usd,
199+
mesh_name=self.mesh_name,
200+
material=default_material,
201+
extract_surface=self.extract_surface,
202+
)
184203
else:
204+
# Load mesh sequence once for both validation and conversion (avoids double I/O)
205+
mesh_sequence = [
206+
read_vtk_file(p, extract_surface=self.extract_surface)
207+
for p in paths_ordered
208+
]
185209
# Optional: validate topology consistency across frames
186210
try:
187-
mesh_sequence = [
188-
read_vtk_file(p, extract_surface=self.extract_surface)
189-
for p in paths_ordered
190-
]
191211
report = validate_time_series_topology(mesh_sequence)
192212
if report.get("topology_changes"):
193213
self.log_warning(
@@ -197,13 +217,12 @@ def run(self) -> str:
197217
except Exception as e:
198218
self.log_debug("Time series validation skipped: %s", e)
199219

200-
stage = converter.convert_sequence(
201-
paths_ordered,
220+
stage = converter.convert_mesh_data_sequence(
221+
mesh_sequence,
202222
self.output_usd,
203223
mesh_name=self.mesh_name,
204224
time_codes=time_codes,
205225
material=default_material,
206-
extract_surface=self.extract_surface,
207226
)
208227

209228
# Post-process: apply chosen appearance to all meshes under /World/Meshes
@@ -215,6 +234,9 @@ def run(self) -> str:
215234
self.log_warning("No mesh prims found under /World/Meshes")
216235
return str(self.output_usd)
217236

237+
# Static merge has no time samples; pass None so only default time is used
238+
appearance_time_codes = None if is_static_merge else time_codes
239+
218240
self.log_info(
219241
"Applying appearance '%s' to %d mesh(es)", self.appearance, len(mesh_paths)
220242
)
@@ -225,7 +247,7 @@ def run(self) -> str:
225247
str(self.output_usd),
226248
mesh_path,
227249
self.solid_color,
228-
time_codes=time_codes,
250+
time_codes=appearance_time_codes,
229251
bind_vertex_color_material=True,
230252
)
231253

@@ -249,7 +271,6 @@ def run(self) -> str:
249271
self.log_warning(
250272
"No color primvar found for %s; skip colormap", mesh_path
251273
)
252-
primvar = self.colormap_primvar
253274
continue
254275
self.log_info(
255276
"Applying colormap to %s from primvar %s", mesh_path, primvar

0 commit comments

Comments
 (0)