Commit 7cfd7d9

Add cvpr workshop (#52)
Fixes # .

### Description
A few sentences describing the changes proposed in this pull request.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] In-line docstrings updated.

Signed-off-by: heyufan1995 <heyufan1995@gmail.com>
Signed-off-by: Yufan He <59374597+heyufan1995@users.noreply.github.com>
1 parent b65b3d1 commit 7cfd7d9

File tree

2 files changed (+15, -7 lines)

vista3d/cvpr_workshop/README.md

Lines changed: 10 additions & 3 deletions
@@ -15,20 +15,27 @@ limitations under the License.
 
 This repository is written for the "CVPR 2025: Foundation Models for Interactive 3D Biomedical Image Segmentation" ([link](https://www.codabench.org/competitions/5263/)) challenge. It is based on MONAI 1.4. Many of the functions in the main VISTA3D repository have been moved into MONAI 1.4, and this simplified folder directly uses components from MONAI.
 
-It is overly simplied to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed.
+It is simplified to train interactive segmentation models across different modalities. The sophisticated transforms and recipes used for VISTA3D are removed. The VISTA3D checkpoint finetuned on the challenge subsets is available [here](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing).
 
 # Setup
 ```
 pip install -r requirements.txt
 ```
 
 # Training
-Download VISTA3D pretrained checkpoint or from scratch. Generate a json list that contains your traning data.
+Download the finetuned challenge-subsets [checkpoint](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) or the original VISTA3D [checkpoint](https://drive.google.com/file/d/1DRYA2-AI-UJ23W1VbjqHsnHENGi0ShUl/view?usp=sharing). Generate a JSON list that contains your training data and update the JSON file path in the script.
 ```
 torchrun --nnodes=1 --nproc_per_node=8 train_cvpr.py
 ```
 
 # Inference
-We provide a Dockerfile to satisfy the challenge format. For more details, refer to the [challenge website]((https://www.codabench.org/competitions/5263/))
+You can directly download the [docker file](https://drive.google.com/file/d/1r2KvHP_30nHR3LU7NJEdscVnlZ2hTtcd/view?usp=sharing) for the challenge baseline.
+
+We provide a Dockerfile to satisfy the challenge format. For more details, refer to the [challenge website](https://www.codabench.org/competitions/5263/).
+
+```
+docker build -t vista3d:latest .
+docker save -o vista3d.tar.gz vista3d:latest
+```
 
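The Training section above asks for a JSON list of training data that `train_cvpr.py` reads (as `subset.json`). The exact schema is not shown in this commit, so the layout below — a `"training"` key holding a list of `.npz` case entries — is purely a hypothetical sketch for illustration, not the format `NPZDataset` is guaranteed to expect:

```python
import json
import os
import tempfile

# Hypothetical layout for the training data list; the real schema expected
# by NPZDataset in train_cvpr.py may differ. Paths here are placeholders.
data_list = {
    "training": [
        {"image": "data/case_0001.npz"},
        {"image": "data/case_0002.npz"},
    ]
}

# Write the list to a JSON file, as the README's "generate a json list" step suggests.
path = os.path.join(tempfile.gettempdir(), "subset.json")
with open(path, "w") as f:
    json.dump(data_list, f, indent=2)

# Read it back the way a dataset class might.
with open(path) as f:
    loaded = json.load(f)
print(len(loaded["training"]))
```

If your schema differs, only the dictionary layout above changes; the write/read round trip stays the same.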

vista3d/cvpr_workshop/train_cvpr.py

Lines changed: 5 additions & 4 deletions
@@ -104,23 +104,24 @@ def __getitem__(self, idx):
         return data
 
 # Training function
 def train():
+    json_file = "subset.json"  # Update with your JSON file
     epoch_number = 100
-    start_epoch = 30
+    start_epoch = 0
     lr = 2e-5
     checkpoint_dir = "checkpoints"
+    start_checkpoint = '/workspace/CPRR25_vista3D_model_final_10percent_data.pth'
     os.makedirs(checkpoint_dir, exist_ok=True)
     dist.init_process_group(backend="nccl")
     world_size = int(os.environ["WORLD_SIZE"])
     local_rank = int(os.environ["LOCAL_RANK"])
     torch.cuda.set_device(local_rank)
     device = torch.device(f"cuda:{local_rank}")
-    json_file = "subset.json"  # Update with your JSON file
     dataset = NPZDataset(json_file)
     sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
     dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=32)
     model = vista3d132(in_channels=1).to(device)
-    # pretrained_ckpt = torch.load('/workspace/VISTA/vista3d/bundles/vista3d/models/model.pt', map_location=device)
-    pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
+    pretrained_ckpt = torch.load(start_checkpoint, map_location=device)
+    # pretrained_ckpt = torch.load(os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth"))
     model = DDP(model, device_ids=[local_rank], find_unused_parameters=True)
     model.load_state_dict(pretrained_ckpt['model'], strict=True)
     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1.0e-05)
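The substantive change in this diff is the starting point: training now begins from the downloaded finetuned checkpoint (`start_checkpoint`, with `start_epoch = 0`) instead of resuming from a locally saved epoch checkpoint such as `checkpoints/model_epoch30.pth`. The helper below is a hypothetical sketch of that branching logic (it does not exist in the repo; `resolve_checkpoint` and its arguments are made up for illustration):

```python
import os

def resolve_checkpoint(start_epoch: int, checkpoint_dir: str, start_checkpoint: str) -> str:
    """Hypothetical helper mirroring the diff's intent: when starting fresh
    (epoch 0), load the downloaded finetuned checkpoint; otherwise resume
    from the locally saved per-epoch checkpoint."""
    if start_epoch == 0:
        return start_checkpoint
    return os.path.join(checkpoint_dir, f"model_epoch{start_epoch}.pth")

# Fresh start: use the downloaded finetuned checkpoint.
fresh = resolve_checkpoint(0, "checkpoints", "/workspace/start.pth")
# Resume: fall back to the per-epoch file in checkpoint_dir.
resumed = resolve_checkpoint(30, "checkpoints", "/workspace/start.pth")
```

Note that in the diff the checkpoint is loaded with `map_location=device`, which maps tensors onto each rank's own GPU rather than the GPU the checkpoint was saved from.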
