Skip to content

Commit 1c5939f

Browse files
Add STARC-9 dataset (#10)
Co-authored-by: PierreMarza <pierre.marza@gmail.com>
1 parent 4694631 commit 1c5939f

File tree

7 files changed

+221
-0
lines changed

7 files changed

+221
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# STARC-9 dataset configuration: 9-class colorectal tissue-patch classification.
dataset_name: starc9
nb_classes: 9
# Resolved via OmegaConf env resolver; datasets live under $THUNDER_BASE_DATA_FOLDER/datasets/.
base_data_folder: ${oc.env:THUNDER_BASE_DATA_FOLDER}/datasets/
compatible_tasks: ["adversarial_attack", "alignment_scoring", "image_retrieval", "knn", "linear_probing", "pre_computing_embeddings", "simple_shot", "transformation_invariance", "zero_shot_vlm"]
# Expected number of samples per split (checked against the generated splits).
nb_train_samples: 630000
nb_val_samples: 18000
nb_test_samples: 54000
# md5 checksum used for integrity verification — target artifact not shown here; verify against the splits-checking code.
md5sum: "3010519777b46827fdb16e656ed74975"
image_sizes: [[256, 256]]
# mpp: presumably microns-per-pixel (standard pathology abbreviation) — confirm.
mpp: 0.5
cancer_type: colorectal
# Class abbreviations; full names under id_to_classname below.
classes: ["ADI", "LYM", "MUC", "MUS", "NCS", "NOR", "BLD", "FCT", "TUM"]
class_to_id:
  ADI: 0
  LYM: 1
  MUC: 2
  MUS: 3
  NCS: 4
  NOR: 5
  BLD: 6
  FCT: 7
  TUM: 8
id_to_class:
  0: ADI
  1: LYM
  2: MUC
  3: MUS
  4: NCS
  5: NOR
  6: BLD
  7: FCT
  8: TUM
id_to_classname:
  0: adipose tissue
  1: lymphoid tissue
  2: mucin
  3: muscle
  4: necrosis
  5: normal mucosa
  6: blood
  7: fibroconnective tissue
  8: tumor

src/thunder/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
spider_colorectal,
1717
spider_skin,
1818
spider_thorax,
19+
starc9,
1920
tcga_crc_msi,
2021
tcga_tils,
2122
tcga_uniform,

src/thunder/datasets/data_splits.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def generate_splits(datasets: Union[List[str], str]) -> None:
3939
"spider_colorectal",
4040
"spider_skin",
4141
"spider_thorax",
42+
"starc9",
4243
]
4344
elif datasets[0] == "classification":
4445
datasets = [
@@ -58,6 +59,7 @@ def generate_splits(datasets: Union[List[str], str]) -> None:
5859
"spider_colorectal",
5960
"spider_skin",
6061
"spider_thorax",
62+
"starc9",
6163
]
6264
elif datasets[0] == "segmentation":
6365
datasets = [
@@ -104,6 +106,7 @@ def generate_splits_for_dataset(dataset_name: str) -> None:
104106
create_splits_spider_colorectal,
105107
create_splits_spider_skin,
106108
create_splits_spider_thorax,
109+
create_splits_starc9,
107110
create_splits_tcga_crc_msi,
108111
create_splits_tcga_tils,
109112
create_splits_tcga_uniform,
@@ -128,6 +131,7 @@ def generate_splits_for_dataset(dataset_name: str) -> None:
128131
"spider_colorectal": create_splits_spider_colorectal,
129132
"spider_skin": create_splits_spider_skin,
130133
"spider_thorax": create_splits_spider_thorax,
134+
"starc9": create_splits_starc9,
131135
# Segmentation
132136
"ocelot": create_splits_ocelot,
133137
"pannuke": create_splits_pannuke,

src/thunder/datasets/dataset/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from .spider_skin import create_splits_spider_skin, download_spider_skin
2727
from .spider_thorax import create_splits_spider_thorax, download_spider_thorax
28+
from .starc9 import create_splits_starc9, download_starc9
2829
from .tcga_crc_msi import create_splits_tcga_crc_msi, download_tcga_crc_msi
2930
from .tcga_tils import create_splits_tcga_tils, download_tcga_tils
3031
from .tcga_uniform import create_splits_tcga_uniform, download_tcga_uniform
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from typing import Dict, List, Optional, Set, Tuple
2+
3+
# Mapping from STARC-9 class abbreviation to integer label id.
# Insertion order defines both the label ids (0-8) and the deterministic
# per-class iteration order used when collecting images.
CLASS_TO_ID = {
    "ADI": 0,  # adipose tissue
    "LYM": 1,  # lymphoid tissue
    "MUC": 2,  # mucin
    "MUS": 3,  # muscle
    "NCS": 4,  # necrosis
    "NOR": 5,  # normal mucosa
    "BLD": 6,  # blood
    "FCT": 7,  # fibroconnective tissue
    "TUM": 8,  # tumor
}

# Accepted image file extensions (compared against lower-cased suffixes).
VALID_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"}
16+
17+
18+
def download_starc9(root_folder: str) -> None:
    """
    Download the STARC-9 dataset from Hugging Face and extract all zip files.

    Final split mapping:
        - train: Training_data_normalized
        - val: Validation_data/STANFORD-CRC-HE-VAL-SMALL
        - test: Validation_data/STANFORD-CRC-HE-VAL-LARGE

    CURATED-TCGA is intentionally ignored here.

    :param root_folder: destination folder for the dataset snapshot.
    """
    # Imported lazily so huggingface_hub is only needed when downloading.
    from huggingface_hub import snapshot_download

    # NOTE(review): local_dir_use_symlinks is deprecated (ignored) in recent
    # huggingface_hub releases — confirm against the pinned version.
    snapshot_download(
        repo_id="Path2AI/STARC-9",
        repo_type="dataset",
        local_dir=root_folder,
        local_dir_use_symlinks=False,
    )

    # The snapshot ships split archives; unpack them in place.
    extract_all_zips(root_folder)
39+
40+
41+
def extract_all_zips(root_dir: str) -> None:
    """
    Extract every .zip found while walking root_dir, each into its own
    containing folder.

    NOTE: directories created by an extraction are not re-scanned —
    os.walk has already listed the current folder's sub-directories by
    the time the archive is unpacked — so zips nested inside freshly
    extracted folders are left untouched.

    :param root_dir: folder tree to scan for .zip archives.
    """
    import os

    from ..utils import unzip_file

    for current_root, _, files in os.walk(root_dir):
        for file_name in files:
            if not file_name.lower().endswith(".zip"):
                continue

            unzip_file(
                os.path.join(current_root, file_name),
                current_root,
            )

            # The LARGE validation archive extracts to a generically named
            # "NORMALIZED" folder; rename it to match the split folder name
            # expected by create_splits_starc9.
            if file_name == "STANFORD-CRC-HE-VAL-LARGE-NORMALIZED.zip":
                extracted = os.path.join(current_root, "NORMALIZED")
                renamed = os.path.join(current_root, "STANFORD-CRC-HE-VAL-LARGE")
                # Guard so a re-run (target already renamed) does not crash
                # with FileNotFoundError from os.rename.
                if os.path.isdir(extracted) and not os.path.exists(renamed):
                    os.rename(extracted, renamed)
66+
67+
68+
def collect_images_from_class_root(
    class_root: str,
    class_to_id: Optional[Dict[str, int]] = None,
    valid_exts: Optional[Set[str]] = None,
) -> Tuple[List[str], List[int]]:
    """
    Read all images from a directory structured like:
        class_root/
            ADI/
            LYM/
            ...

    Files are gathered per class in sorted path order, so the output is
    deterministic across runs. (Fixes the previous return annotation,
    which wrongly claimed a third Dict element.)

    :param class_root: folder containing one sub-folder per class.
    :param class_to_id: mapping from class folder name to integer label;
        defaults to the module-level STARC-9 ``CLASS_TO_ID``.
    :param valid_exts: accepted lower-case file extensions; defaults to
        the module-level ``VALID_EXTS``.
    :return: (absolute image paths, matching integer labels).
    :raises FileNotFoundError: if class_root or any expected class folder
        is missing.
    """
    from pathlib import Path

    if class_to_id is None:
        class_to_id = CLASS_TO_ID
    if valid_exts is None:
        valid_exts = VALID_EXTS

    images: List[str] = []
    labels: List[int] = []

    class_root_path = Path(class_root)
    if not class_root_path.exists():
        raise FileNotFoundError(f"Class root does not exist: {class_root}")

    # Fail loudly if any expected class folder is absent, rather than
    # silently producing an incomplete split.
    missing_classes = [c for c in class_to_id if not (class_root_path / c).exists()]
    if missing_classes:
        raise FileNotFoundError(
            f"Missing expected class folders under {class_root}: {missing_classes}"
        )

    for class_name, class_id in class_to_id.items():
        class_dir = class_root_path / class_name
        for img_path in sorted(class_dir.rglob("*")):
            # Extension check is case-insensitive (suffix is lower-cased).
            if img_path.is_file() and img_path.suffix.lower() in valid_exts:
                images.append(str(img_path.resolve()))
                labels.append(class_id)

    return images, labels
101+
102+
103+
def create_splits_starc9(base_folder: str, dataset_cfg: dict) -> None:
    """
    Generating data splits for the STARC-9 dataset.

    :param base_folder: path to the main folder storing datasets.
    :param dataset_cfg: dataset-specific config.
    """
    import os

    from ...utils.constants import UtilsConstants
    from ...utils.utils import set_seed
    from ..data_splits import (
        check_dataset,
        create_few_shot_training_data,
        init_dict,
        save_dict,
    )

    # Fix the random seed so few-shot sampling is reproducible.
    set_seed(UtilsConstants.DEFAULT_SEED.value)

    # Fresh splits dict to populate.
    splits = init_dict()

    # Each split maps to a fixed folder layout produced by download_starc9.
    dataset_root = os.path.join(base_folder, "starc9")
    split_roots = {
        "train": os.path.join(dataset_root, "Training_data_normalized"),
        "val": os.path.join(
            dataset_root, "Validation_data", "STANFORD-CRC-HE-VAL-SMALL"
        ),
        "test": os.path.join(
            dataset_root, "Validation_data", "STANFORD-CRC-HE-VAL-LARGE"
        ),
    }

    # Collect image paths and labels for every split.
    for split_name, split_root in split_roots.items():
        image_paths, image_labels = collect_images_from_class_root(split_root)
        splits[split_name]["images"] = image_paths
        splits[split_name]["labels"] = image_labels

    # Derive the few-shot training subsets.
    splits = create_few_shot_training_data(splits)

    # Validate against the dataset config (counts, checksum, ...).
    check_dataset(
        splits,
        dataset_cfg,
        base_folder,
    )

    # Persist the splits to disk.
    save_dict(splits, os.path.join(base_folder, "data_splits", "starc9.json"))

src/thunder/datasets/download.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def download_datasets(datasets: Union[List[str], str], make_splits: bool = False
2727
* spider_colorectal
2828
* spider_skin
2929
* spider_thorax
30+
* starc9
3031
* tcga_crc_msi
3132
* tcga_tils
3233
* tcga_uniform
@@ -65,6 +66,7 @@ def download_datasets(datasets: Union[List[str], str], make_splits: bool = False
6566
"spider_colorectal",
6667
"spider_skin",
6768
"spider_thorax",
69+
"starc9",
6870
"tcga_crc_msi",
6971
"tcga_tils",
7072
"tcga_uniform",
@@ -84,6 +86,7 @@ def download_datasets(datasets: Union[List[str], str], make_splits: bool = False
8486
"spider_colorectal",
8587
"spider_skin",
8688
"spider_thorax",
89+
"starc9",
8790
"tcga_crc_msi",
8891
"tcga_tils",
8992
"tcga_uniform",
@@ -160,5 +163,7 @@ def download_dataset(dataset: str):
160163
download_spider_skin(root_folder)
161164
elif dataset == "spider_thorax":
162165
download_spider_thorax(root_folder)
166+
elif dataset == "starc9":
167+
download_starc9(root_folder)
163168
else:
164169
raise ValueError(f"Dataset {dataset} is not supported.")

src/thunder/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class DatasetConstants(Enum):
6262
"spider_colorectal",
6363
"spider_skin",
6464
"spider_thorax",
65+
"starc9",
6566
"tcga_crc_msi",
6667
"tcga_tils",
6768
"tcga_uniform",

0 commit comments

Comments
 (0)