Commit 5c2c86b

committed
update readme
Signed-off-by: heyufan1995 <heyufan1995@gmail.com>
1 parent 8bb7572 commit 5c2c86b

File tree

4 files changed: +103 −62 lines changed

vista3d/README.md

Lines changed: 101 additions & 60 deletions
@@ -13,6 +13,7 @@ limitations under the License.
 
 # MONAI **V**ersatile **I**maging **S**egmen**T**ation and **A**nnotation
 [[`Paper`](https://arxiv.org/pdf/2406.05285)] [[`Demo`](https://build.nvidia.com/nvidia/vista-3d)] [[`Checkpoint`](https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing)]
+<div align="center"> <img src="./assets/imgs/workflow.png" width="100%"/> </div>
 
 ## News!
 [03/12/2025] We provide VISTA3D as a baseline for the challenge "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation" ([link](https://www.codabench.org/competitions/5263/)). The simplified code, based on MONAI 1.4, is provided [here](./cvpr_workshop/).
@@ -21,7 +22,7 @@ limitations under the License.
 ## Overview
 
 **VISTA3D** is a foundation model trained systematically on 11,454 volumes encompassing 127 types of human anatomical structures and various lesions. It provides accurate out-of-the-box segmentation that matches state-of-the-art supervised models trained on each individual dataset. The model also achieves state-of-the-art zero-shot interactive segmentation in 3D, representing a promising step toward a versatile medical image foundation model.
-<div align="center"> <img src="./assets/imgs/scores.png" width="800"/> </div>
+
 
 ### Out-of-the-box automatic segmentation
 For the 127 supported classes, the model performs highly accurate out-of-the-box segmentation. The fully automated process adopts patch-based sliding-window inference and only requires a class prompt.
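The patch-based sliding-window scheme mentioned above can be sketched in a few lines of plain Python: tile the volume into overlapping patches and make sure the tail of each axis is still covered. This is a toy illustration with an assumed patch size and 25% overlap, not the MONAI implementation.

```python
def patch_starts(dim_size, patch, overlap=0.25):
    """Start offsets of overlapping patches covering one axis."""
    step = max(1, int(patch * (1 - overlap)))
    starts = list(range(0, max(dim_size - patch, 0) + 1, step))
    if starts[-1] + patch < dim_size:  # ensure the tail of the axis is covered
        starts.append(dim_size - patch)
    return starts

def tile_volume(shape, patch=(128, 128, 64)):
    """All (x, y, z) patch origins for a 3D volume of the given shape."""
    return [(x, y, z)
            for x in patch_starts(shape[0], patch[0])
            for y in patch_starts(shape[1], patch[1])
            for z in patch_starts(shape[2], patch[2])]
```

Each patch is then segmented independently and the overlapping predictions are blended back into the full volume.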
@@ -66,52 +67,124 @@ This capability makes the model even more flexible and accelerates practical seg
 </figure>
 </div>
 
-### Fine-tuning
-VISTA3D checkpoint showed improvements when finetuning in few-shot settings. Once a few annotated examples are provided, user can start finetune with the VISTA3D checkpoint.
-<div align="center"> <img src="./assets/imgs/finetune.png" width="600"/> </div>
 
 ## Usage
 
-### Installation
+## Installation
 To perform inference locally with a debugger GUI, simply install
-```
-git clone https://github.com/Project-MONAI/VISTA.git;
-cd ./VISTA/vista3d;
+```bash
+git clone https://github.com/Project-MONAI/VISTA.git
+cd ./VISTA/vista3d
+conda create -y -n vista3d python=3.9
+conda activate vista3d
 pip install -r requirements.txt
 ```
 Download the [model checkpoint](https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing) and save it at `./models/model.pt`.
 
-### Inference
-The [NIM Demo (VISTA3D NVIDIA Inference Microservices)](https://build.nvidia.com/nvidia/vista-3d) does not support medical data upload due to legal concerns.
-We provide scripts for inference locally. The automatic segmentation label definition can be found at [label_dict](./data/jsons/label_dict.json).
-
-#### MONAI Bundle
-
-For automatic segmentation and batch processing, we highly recommend using the MONAI model zoo. The [MONAI bundle](https://github.com/Project-MONAI/model-zoo/tree/dev/models/vista3d) wraps VISTA3D and provides a unified API for inference, and the [NIM Demo](https://build.nvidia.com/nvidia/vista-3d) deploys the bundle with an interactive front-end. Although NIM Demo cannot run locally, the bundle is available and can run locally. The following command will download the vista3d standalone bundle. The documentation in the bundle contains a detailed explanation for finetuning and inference.
+## Inference
+The current repo is the research codebase for the CVPR 2025 paper, built on MONAI 1.3. We converted the model into a [MONAI bundle](https://github.com/Project-MONAI/model-zoo/tree/dev/models/vista3d) with improved GPU utilization and speed (it is the backend for the [demo](https://build.nvidia.com/nvidia/vista-3d)). The automatic segmentation label definition can be found in [label_dict](./data/jsons/label_dict.json). For the exact number of supported automatic segmentation classes and the reasoning behind it, please refer to this [issue](https://github.com/Project-MONAI/VISTA/issues/41).
+<div align="center"> <img src="./assets/imgs/scores.png" width="800"/> </div>
 
+### 1. Recommended: MONAI Bundle (model zoo)
+
+```bash
+# use the same conda env as this repo
+conda activate vista3d
+pip install monai==1.4
+git clone https://github.com/Project-MONAI/model-zoo.git
+mv model-zoo/models/vista3d vista3dbundle && rm -rf model-zoo
+cd vista3dbundle
+mkdir models
+# minor model-weights naming conversion due to the MONAI version change
+wget -O models/model.pt https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt
+```
+The MONAI bundle accepts multiple JSON config files and input arguments; later configs/arguments override earlier ones when they have overlapping keys.
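The override behavior described in the added line above (later config files and CLI arguments win on overlapping keys) amounts to a shallow left-to-right merge. A plain-Python sketch of those semantics, with made-up keys for illustration — not MONAI's actual implementation:

```python
def merge_configs(*configs):
    """Later configs override earlier ones on overlapping keys (shallow merge)."""
    merged = {}
    for cfg in configs:
        merged.update(cfg)
    return merged

# hypothetical contents standing in for the real config files / CLI arguments
inference = {"input_dir": "./data", "output_dir": "./eval", "device": "cuda:0"}
batch = {"output_dir": "./eval_task09"}                 # overrides output_dir
cli_args = {"input_dir": "/data/Task09_Spleen/imagesTr"}  # overrides input_dir

final = merge_configs(inference, batch, cli_args)
```

Keys not mentioned by a later config (`device` here) keep their earlier values.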
+```bash
+# Automatic: segment everything
+python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz'}"
+```
+```bash
+# Automatic: segment a specific class
+python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','label_prompt':[3]}"
+```
+```bash
+# Interactive segmentation
+# Points must be three-dimensional (x,y,z), in the shape [[x,y,z],...,[x,y,z]]. Point labels can only be -1 (ignore), 0 (negative), 1 (positive), 2 (negative for special overlapped classes like tumor), or 3 (positive for special classes). Only 1 class per inference is supported. The output value 255 represents NaN, i.e. a region that was not processed.
+python -m monai.bundle run --config_file configs/inference.json --input_dict "{'image':'spleen_03.nii.gz','points':[[128,128,16], [100,100,16]],'point_labels':[1, 0]}"
 ```
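The point-prompt constraints stated above (3D points, labels restricted to -1/0/1/2/3) can be captured in a small validator. This is a hypothetical helper for illustration — `check_point_prompt` is not part of the repo or the bundle:

```python
VALID_POINT_LABELS = {-1, 0, 1, 2, 3}  # ignore / negative / positive / special neg / special pos

def check_point_prompt(points, point_labels):
    """Validate an interactive-segmentation prompt per the README rules."""
    if len(points) != len(point_labels):
        raise ValueError("points and point_labels must have the same length")
    for p in points:
        if len(p) != 3:
            raise ValueError(f"each point must be (x, y, z), got {p}")
    bad = set(point_labels) - VALID_POINT_LABELS
    if bad:
        raise ValueError(f"invalid point labels: {sorted(bad)}")
    return True
```

Running such a check before building `input_dict` fails fast instead of producing a confusing inference error.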
-pip install "monai[fire]"
-python -m monai.bundle download "vista3d" --bundle_dir "bundles/"
+```bash
+# Automatic batch segmentation for a whole folder
+python -m monai.bundle run --config_file="['configs/inference.json', 'configs/batch_inference.json']" --input_dir="/data/Task09_Spleen/imagesTr" --output_dir="./eval_task09"
+```
+```bash
+# Automatic batch segmentation for a whole folder with multi-GPU support; mgpu_inference.json is shown below.
+python -m monai.bundle run --config_file="['configs/inference.json', 'configs/batch_inference.json', 'configs/mgpu_inference.json']" --input_dir="/data/Task09_Spleen/imagesTr" --output_dir="./eval_task09"
+```
+<details>
+<summary><b>Click to see mgpu_inference.json</b></summary>
+
+```json
+{
+    "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
+    "network": {
+        "_target_": "torch.nn.parallel.DistributedDataParallel",
+        "module": "$@network_def.to(@device)",
+        "device_ids": ["@device"]
+    },
+    "sampler": {
+        "_target_": "DistributedSampler",
+        "dataset": "@dataset",
+        "even_divisible": false,
+        "shuffle": false
+    },
+    "dataloader#sampler": "@sampler",
+    "initialize": [
+        "$import torch.distributed as dist",
+        "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
+        "$torch.cuda.set_device(@device)"
+    ],
+    "run": ["$@evaluator.run()"],
+    "finalize": ["$dist.is_initialized() and dist.destroy_process_group()"]
+}
+```
+</details>
+
+### 1.1 Overlapped classes and postprocessing with [ShapeKit](https://arxiv.org/pdf/2506.24003)
+VISTA3D is trained with binary segmentation and may produce false positives due to weak false-positive supervision. ShapeKit addresses this with sophisticated postprocessing. ShapeKit requires a segmentation mask for each class, but VISTA3D by default performs argmax and collapses overlapping classes. Change the `monai.apps.vista3d.transforms.VistaPostTransformd` entry in `inference.json` to the transform below, so that each class segmentation is saved as a separate channel, then follow the [ShapeKit](https://github.com/BodyMaps/ShapeKit) codebase for processing.
+```json
+{
+    "_target_": "Activationsd",
+    "sigmoid": true,
+    "keys": "pred"
+},
 ```
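The difference between the default argmax collapse and the per-channel sigmoid output that ShapeKit needs can be shown on a toy example (made-up probabilities; plain Python, not the bundle's actual transform code):

```python
def argmax_collapse(channel_probs):
    """Default behavior: one label per voxel; overlapping classes are collapsed."""
    return [max(range(len(p)), key=lambda c: p[c]) for p in channel_probs]

def per_channel_masks(channel_probs, threshold=0.5):
    """ShapeKit-style input: an independent binary mask per class (overlaps kept)."""
    n_classes = len(channel_probs[0])
    return [[int(p[c] > threshold) for p in channel_probs] for c in range(n_classes)]

# three voxels, two classes; at the first voxel a lesion (class 1) overlaps an organ (class 0)
probs = [[0.9, 0.7], [0.8, 0.1], [0.2, 0.6]]
labels = argmax_collapse(probs)   # class 1 disappears at the overlapping voxel
masks = per_channel_masks(probs)  # both classes keep their own mask there
```

With per-channel masks the overlap survives, which is exactly what the `Activationsd` sigmoid output preserves.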
 
-#### Debugger
+### 2. VISTA3D Research repository (this repo)
 
 We provide the `infer.py` script and its light-weight front-end `debugger.py`. Users can directly launch a local interface for both automatic and interactive segmentation.
 
-```
+```bash
 python -m scripts.debugger run
 ```
 or directly call `infer.py` to generate automatic segmentation. To segment a liver (label_prompt=1 as defined in [label_dict](./data/jsons/label_dict.json)), run
-```
+```bash
 export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --label_prompt "[1]" --save_mask true
 ```
 To segment everything, run
-```
+```bash
 export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer_everything --image_file 'example-1.nii.gz'
 ```
+To segment based on point clicks, provide `point` and `point_label`.
+```bash
+export CUDA_VISIBLE_DEVICES=0; python -m scripts.infer --config_file 'configs/infer.yaml' - infer --image_file 'example-1.nii.gz' --point "[[128,128,16],[100,100,6]]" --point_label "[1,0]" --save_mask true
+```
 The output path and other configs can be changed in `configs/infer.yaml`.
-
-
 ```
 NOTE: `infer.py` does not support `lung`, `kidney`, and `bone` class segmentation, while the MONAI bundle supports those classes. The MONAI bundle uses better memory management and will not easily face OOM issues.
 ```
@@ -146,6 +219,8 @@ For zero-shot, we perform iterative point sampling. To create a new zero-shot ev
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.validation.val_multigpu_point_iterative run --config_file "['configs/zeroshot_eval/infer_iter_point_hcc.yaml']"
 ```
 ### Finetune
+The VISTA3D checkpoint showed improvements when fine-tuned in few-shot settings. Once a few annotated examples are available, users can start fine-tuning from the VISTA3D checkpoint.
+<div align="center"> <img src="./assets/imgs/finetune.png" width="600"/> </div>
 For fine-tuning, users need to change `label_set` and `mapped_label_set` in the JSON config, where `label_set` matches the index values in the ground-truth files. The `mapped_label_set` can be randomly selected, but we recommend picking the most related global indices defined in [label_dict](./data/jsons/label_dict.json). Users should modify the transforms, resolutions, patch sizes, etc. for their dataset for optimal fine-tuning performance; we recommend using configs generated by auto3dseg. A learning rate of 5e-5 should be good enough for fine-tuning purposes.
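The `label_set` / `mapped_label_set` relationship described above amounts to remapping local ground-truth indices onto VISTA3D's global class indices, position by position. A toy sketch — liver = 1 comes from label_dict, but the tumor index 26 and the helper names here are assumptions for illustration:

```python
def build_label_mapping(label_set, mapped_label_set):
    """Pair local ground-truth indices with VISTA3D global class indices, position by position."""
    if len(label_set) != len(mapped_label_set):
        raise ValueError("label_set and mapped_label_set must align one-to-one")
    return dict(zip(label_set, mapped_label_set))

def remap_groundtruth(volume_flat, mapping):
    """Apply the mapping voxel-wise; background (0) stays 0."""
    return [mapping.get(v, 0) for v in volume_flat]

# local dataset: 1 = liver, 2 = a tumor class; global: liver is 1, tumor index assumed
mapping = build_label_mapping([1, 2], [1, 26])
remapped = remap_groundtruth([0, 1, 2, 1], mapping)
```

Picking globally related indices for `mapped_label_set` lets fine-tuning start from class embeddings that already resemble the target anatomy.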
 ```
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.train_finetune run --config_file "['configs/finetune/train_finetune_word.yaml']"
@@ -155,40 +230,6 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node
 Note: MONAI bundle also provides a unified API for finetuning, but the results in the table and paper are from this research repository.
 ```
 
-### NEW! [SAM2 Benchmark Tech Report](https://arxiv.org/abs/2408.11210)
-We provide scripts to run SAM2 evaluation. Modify SAM2 source code to support background remove: Add `z_slice` to `sam2_video_predictor.py`. Require SAM2 package [installation](https://github.com/facebookresearch/segment-anything-2)
-```
-@torch.inference_mode()
-def init_state(
-    self,
-    video_path,
-    offload_video_to_cpu=False,
-    offload_state_to_cpu=False,
-    async_loading_frames=False,
-    z_slice=None
-):
-    """Initialize a inference state."""
-    images, video_height, video_width = load_video_frames(
-        video_path=video_path,
-        image_size=self.image_size,
-        offload_video_to_cpu=offload_video_to_cpu,
-        async_loading_frames=async_loading_frames,
-    )
-    if z_slice is not None:
-        images = images[z_slice]
-```
-Run evaluation
-```
-export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7;torchrun --nnodes=1 --nproc_per_node=8 -m scripts.validation.val_multigpu_sam2_point_iterative run --config_file "['configs/supported_eval/infer_sam2_point.yaml']" --saliency False --dataset_name 'Task06'
-```
-<div align="center">
-<figure>
-<img src="assets/imgs/sam2.png">
-<figcaption> Initial comparison with SAM2's zero-shot performance. </figcaption>
-</figure>
-</div>
-
 
 ## Community
 
@@ -205,10 +246,10 @@ The codebase is under Apache 2.0 Licence. The model weight is released under [NV
 
 ```
 @article{he2024vista3d,
-  title={VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography},
+  title={VISTA3D: A Unified Segmentation Foundation Model For 3D Medical Imaging},
   author={He, Yufan and Guo, Pengfei and Tang, Yucheng and Myronenko, Andriy and Nath, Vishwesh and Xu, Ziyue and Yang, Dong and Zhao, Can and Simon, Benjamin and Belue, Mason and others},
-  journal={arXiv preprint arXiv:2406.05285},
-  year={2024}
+  journal={CVPR},
+  year={2025}
 }
 ```
 
vista3d/assets/imgs/workflow.png

387 KB

vista3d/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@ nibabel==5.2.1
 numpy==1.24.4
 Pillow==10.4.0
 PyYAML==6.0.2
-scipy==1.14.0
+scipy
 scikit-image==0.24.0
 torch==2.0.1
 tqdm==4.66.2

vista3d/scripts/infer.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         self.model = model.to(self.device)
 
         pretrained_ckpt = torch.load(ckpt_name, map_location=self.device)
-        self.model.load_state_dict(pretrained_ckpt, strict=False)
+        self.model.load_state_dict(pretrained_ckpt, strict=True)
         logger.debug(f"[debug] checkpoint {ckpt_name:s} loaded")
         post_transforms = [
             VistaPostTransform(keys="pred"),
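The `strict=False` → `strict=True` change above makes checkpoint loading fail loudly on key mismatches instead of silently skipping weights. The semantics can be sketched without torch — a simplified illustration of the behavior, not PyTorch's actual implementation:

```python
def load_state_dict(model_keys, checkpoint, strict=True):
    """Return the weights that would be loaded; with strict=True, any key mismatch raises."""
    missing = [k for k in model_keys if k not in checkpoint]
    unexpected = [k for k in checkpoint if k not in model_keys]
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    # non-strict mode: load only the intersection, leaving other weights untouched
    return {k: checkpoint[k] for k in model_keys if k in checkpoint}
```

With `strict=False`, a renamed or absent layer would be silently left at its random initialization, which is exactly the failure mode this commit guards against.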
