Skip to content

Commit a45deed

Browse files
committed
Merge remote-tracking branch 'origin/vista3d' into vista3d-export
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
2 parents 3e4a84b + 24567b9 commit a45deed

File tree

8 files changed

+42
-29
lines changed

8 files changed

+42
-29
lines changed

README.md

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,45 @@ The **VISTA3D** is a foundation model trained systematically on 11,454 volumes e
2121
### Out-of box automatic segmentation
2222
For supported 127 classes, the model can perform highly accurate out-of-box segmentation. The fully automated process adopts a patch-based sliding-window inference and only requires a class prompt.
2323
Compared to supervised segmentation models trained on each dataset separately, VISTA3D showed comparable out-of-box performances and strong generalizability ('VISTA3D auto' in Table.1).
24-
<!-- <div align="center"> <img src="" width="800"/> </div> -->
24+
<!-- <div align="center"> <img src="assets/imgs/everything.gif" width="800"/> </div> -->
25+
<div align="center">
26+
<figure>
27+
<img
28+
src="assets/imgs/everything.gif">
29+
<figcaption> NIM Demo supports "Segment Everything" </figcaption>
30+
</figure>
31+
</div>
32+
33+
2534

2635
### Interactive editing
2736
The interactive segmentation is based on user-provided clicks. Each click point will impact a local 3D patch. User can either effectively refine the automatic results with clicks ('VISTA3D auto+point' in Table.1) or simply provide a click without specifying the target class ('VISTA3D point' in Table.1) .
2837
<!-- <div align="center"> <img src="" width="800"/> </div> -->
38+
<div align="center">
39+
<figure>
40+
<img
41+
src="assets/imgs/liver.gif">
42+
<figcaption> Specify a supported class and edit the automatic results </figcaption>
43+
</figure>
44+
</div>
45+
<div align="center">
46+
<figure>
47+
<img
48+
src="assets/imgs/unspecified.gif">
49+
<figcaption> Interactive supported class segmentation without specifying class </figcaption>
50+
</figure>
51+
</div>
2952

3053
### Zero-shot interactive segmentation
3154
VISTA3D is built to produce visually plausible segmentations on previously unseen classes.
3255
This capability makes the model even more flexible and accelerates practical segmentation data curation processes.
56+
<div align="center">
57+
<figure>
58+
<img
59+
src="assets/imgs/zeroshot.gif">
60+
<figcaption> Add a new unseen class and do annotation </figcaption>
61+
</figure>
62+
</div>
3363

3464
### Fine-tuning
3565
VISTA3D checkpoint showed improvements when finetuning in few-shot settings. Once a few annotated examples are provided, user can start finetune with the VISTA3D checkpoint.
@@ -98,7 +128,7 @@ Ask and answer questions on [MONAI VISTA's GitHub discussions tab](https://githu
98128

99129
## License
100130

101-
The codebase is under Apache 2.0 Licence.
131+
The codebase is under Apache 2.0 Licence. The model weight is under special NVIDIA license.
102132

103133
## Reference
104134

assets/imgs/everything.gif

95.3 MB
Loading

assets/imgs/liver.gif

4.61 MB
Loading

assets/imgs/unspecified.gif

1.97 MB
Loading

assets/imgs/zeroshot.gif

2.13 MB
Loading

scripts/debugger.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,16 @@ def on_button_click(event, ax=ax):
123123
print("-- segmenting ---")
124124
self.generate_mask()
125125
print("-- done ---")
126+
print("-- Note: Point only prompts will only do 128 cubic segmentation, a cropping artefact will be observed. ---")
127+
print("-- Note: Point without class will be treated as supported class, which has worse zero-shot ability. Try class > 132 to perform better zeroshot. ---")
128+
print("-- Note: CTRL + Right Click will be adding negative points. ---")
126129
print(
127130
"-- Note: Click points on different foreground class will cause segmentation conflicts. Clear first. ---"
128131
)
129132
print(
130133
"-- Note: Click points not matching class prompts will also cause confusion. ---"
131134
)
132-
print("-- Note: CTRL + Right Click will be adding negative points. ---")
135+
133136
self.update_slice(ax)
134137
# self.point_start = len(self.clicked_points)
135138

vista3d/modeling/segresnetds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def _forward(self, x: torch.Tensor) -> list[torch.Tensor]:
238238

239239
if self.head_module is not None:
240240
outputs = self.head_module(outputs)
241+
241242
return outputs
242243

243244
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
@@ -463,7 +464,7 @@ def is_valid_shape(self, x):
463464

464465
def _forward(
465466
self, x: torch.Tensor, with_point, with_label
466-
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
467+
) -> Union[None, torch.Tensor, list[torch.Tensor]]:
467468
if self.preprocess is not None:
468469
x = self.preprocess(x)
469470

@@ -522,7 +523,7 @@ def _forward(
522523
return outputs, outputs_auto
523524

524525
def forward(
525-
self, x: torch.Tensor, with_point=True, with_label=True, # **kwargs
526+
self, x: torch.Tensor, with_point=True, with_label=True, **kwargs
526527
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
527528
return self._forward(x, with_point, with_label)
528529

vista3d/modeling/vista3d.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import torch
1717
import torch.nn as nn
1818
from monai.utils import optional_import
19-
import time
2019

2120
from scripts.utils.trans_utils import convert_points_to_disc
2221
from scripts.utils.trans_utils import get_largest_connected_component_mask as lcc
@@ -42,8 +41,7 @@ def __init__(self, image_encoder, class_head, point_head, feature_size):
4241
)
4342
self.auto_freeze = False
4443
self.point_freeze = False
45-
self.engine = None
46-
44+
4745
def precompute_embedding(self, input_images):
4846
"""precompute image embedding, require sliding window inference"""
4947
raise NotImplementedError
@@ -205,8 +203,6 @@ def set_auto_grad(self, auto_freeze=False, point_freeze=False):
205203
param.requires_grad = not point_freeze
206204
self.point_freeze = point_freeze
207205

208-
209-
210206
def forward(
211207
self,
212208
input_images,
@@ -249,8 +245,6 @@ def forward(
249245
val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
250246
251247
"""
252-
time00 = time.time()
253-
254248
image_size = input_images.shape[-3:]
255249
device = input_images.device
256250
if point_coords is None and class_vector is None:
@@ -313,16 +307,13 @@ def forward(
313307
enable_all_tactics=True
314308
)
315309

316-
time0 = time.time()
317310
out, out_auto = self.image_encoder(
318311
x=input_images,
319312
with_point=point_coords is not None,
320313
with_label=class_vector is not None,
321314
)
322-
# torch.cuda.synchronize()
323-
# time1 = time.time()
324-
# print(f"Encoder Time: {time.time() - time0}, shape : {input_images.shape}, point: {point_coords is not None}")
325-
input_images = None
315+
input_images = None
316+
326317
# force releasing memories that set to None
327318
torch.cuda.empty_cache()
328319
if class_vector is not None:
@@ -333,19 +324,12 @@ def forward(
333324
dynamo=False,
334325
verbose=False,
335326
)
336-
# time2 = time.time()
337327
logits, _ = self.class_head(src=out_auto, class_vector=class_vector)
338-
# torch.cuda.synchronize()
339-
# print(f"Class Head Time: {time.time() - time2}")
340328

341329
if point_coords is not None:
342-
# time3 = time.time()
343330
point_logits = self.point_head(
344331
out, point_coords, point_labels, class_vector=prompt_class
345332
)
346-
# torch.cuda.synchronize()
347-
# print(f"Point Head Time: {time.time() - time3}")
348-
# time4 = time.time()
349333
if patch_coords is None:
350334
logits = self.gaussian_combine(
351335
logits,
@@ -360,8 +344,6 @@ def forward(
360344
logits = self.connected_components_combine(
361345
logits, point_logits, point_coords, point_labels, mapping_index
362346
)
363-
# torch.cuda.synchronize()
364-
# print(f"Combine Time: {time.time() - time4}")
365347
else:
366348
logits = NINF_VALUE + torch.zeros(
367349
[bs, 1, *image_size], device=device, dtype=out.dtype
@@ -378,9 +360,6 @@ def forward(
378360
mapping_index,
379361
)
380362

381-
# torch.cuda.synchronize()
382-
# print(f"Total time : {time.time() - time00} shape : {logits.shape}")
383-
384363
if kwargs.get("keep_cache", False) and class_vector is None:
385364
self.image_embeddings = out.detach()
386365
return logits

0 commit comments

Comments
 (0)