-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathensemble_voting_rule_semantic.py
More file actions
114 lines (91 loc) · 4.26 KB
/
ensemble_voting_rule_semantic.py
File metadata and controls
114 lines (91 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import imageio
import numpy as np
import os
from pathlib import Path
from postProc import postProc
from skimage.filters import threshold_otsu
def binarize(im, thr=None):
    """Convert an image into a binary 0/1 mask.

    Args:
        im: Image data; anything ``np.asarray`` accepts (ndarray or nested
            sequences).
        thr: Cut-off value. When ``None``, it is computed from the image
            itself with ``threshold_otsu``.

    Returns:
        A ``np.uint8`` array shaped like ``im`` holding 1 where the pixel is
        strictly greater than the threshold and 0 elsewhere.
    """
    # Coerce up front: comparing a plain list with a number raises TypeError,
    # whereas an ndarray compares element-wise.
    im = np.asarray(im)
    if thr is None:
        thr = threshold_otsu(im)
    return (im > thr).astype(np.uint8)
def _find_original_image(original_images_dir, cls, name):
    """Return the path of the original image for *name*, or None if absent."""
    for ext in (".jpg", ".png", ".tif", ".tiff"):
        candidate = Path(original_images_dir) / cls / f"{name}{ext}"
        if candidate.exists():
            return candidate
    return None


def _load_binary_mask(path):
    """Read one mask file and return it as a single-channel 0/1 array."""
    img = imageio.imread(path)
    if img.ndim == 3:
        img = img[..., 0]  # multi-channel mask: keep only the first channel
    return binarize(img)


def process_all(
    root_dir,
    class_list,
    postProc,
    original_images_dir,
    output_dir,
    weight_vector,
    threshold_percent,
):
    """Fuse the masks of several methods by weighted voting and post-process.

    Expected layout (as read by this function):
        ``<root_dir>/<method>/masks/<class>/<image>.*``   input masks
        ``<original_images_dir>/<class>/<image>.{jpg,png,tif,tiff}``  originals
    Written layout:
        ``<output_dir>/masks/<class>/<image>.png``
        ``<output_dir>/segmentation/<class>/<image>.png``

    For every class, only images whose file stem exists in every method that
    has a folder for that class are processed. Each mask is binarized, the
    binary masks are summed with their weights, and a pixel is kept when the
    weighted sum is strictly greater than
    ``threshold_percent * sum(weight_vector)``.

    Args:
        root_dir: Folder holding one sub-folder per method.
        class_list: Class names (sub-folder names such as "healthy", "mild").
        postProc: Callable invoked as ``postProc(mask, original_image)``; it
            must return ``(post_processed_mask, post_processed_image)``.
        original_images_dir: Folder holding the original images per class.
        output_dir: Folder where results are written (created on demand).
        weight_vector: One weight per method. Weights are paired with the
            method folders sorted by folder name.
        threshold_percent: Fraction of the total weight a pixel must exceed.

    Raises:
        ValueError: If the number of weights differs from the number of
            methods available for a class, or if the masks of one image do
            not all share the same shape.
    """
    root = Path(root_dir)
    output_dir = Path(output_dir)

    # Every sub-folder of root is a method. iterdir() yields entries in
    # arbitrary order, so sort by name to keep the weight <-> method pairing
    # stable across runs and machines.
    method_dirs = sorted(
        (p for p in root.iterdir() if p.is_dir()), key=lambda p: p.name
    )
    print("Métodos encontrados:")
    for m in method_dirs:
        print(" -", m.name)

    # Weights and threshold do not depend on the class: compute them once.
    weights = np.array(weight_vector, dtype=float)
    thr_value = threshold_percent * weights.sum()

    for cls in class_list:
        print(f"\n>>> Processando classe: {cls}")

        # Map each method to {file stem: path} for this class.
        method_files = {}
        for method_dir in method_dirs:
            mask_dir = method_dir / "masks" / cls
            if not mask_dir.exists():
                print(f"Aviso: {mask_dir} não existe. Método será ignorado.")
                continue
            # Any extension is accepted, but sub-directories are not masks.
            method_files[method_dir.name] = {
                f.stem: f for f in sorted(mask_dir.glob("*")) if f.is_file()
            }

        if not method_files:
            # set.intersection() with no operands raises TypeError; there is
            # nothing to fuse for this class anyway.
            print(f"Aviso: nenhum método possui máscaras para a classe {cls}. Classe ignorada.")
            continue

        # Image names available in ALL methods.
        common_names = set.intersection(
            *(set(files) for files in method_files.values())
        )
        print(f"{len(common_names)} imagens em comum encontradas.")

        if len(weights) != len(method_files):
            raise ValueError(f"Número de pesos ({len(weights)}) não coincide com número de métodos ({len(method_files)}).")
        print(f"Threshold = {threshold_percent*100:.1f}% do total → {thr_value:.3f}")

        out_mask = output_dir / "masks" / cls
        out_seg = output_dir / "segmentation" / cls

        for name in sorted(common_names):
            print(f" - Processando {name}")

            # Look for the original first so no mask is loaded for an image
            # that will be skipped.
            original_path = _find_original_image(original_images_dir, cls, name)
            if original_path is None:
                print(f"Original não encontrado para {name}. Pulando.")
                continue

            masks = [_load_binary_mask(files[name]) for files in method_files.values()]
            shapes = {m.shape for m in masks}
            if len(shapes) > 1:
                raise ValueError(f"Máscaras de {name} têm tamanhos diferentes entre os métodos: {sorted(shapes)}.")

            # Weighted vote per pixel, then threshold to a 0/255 mask.
            votes = np.zeros(masks[0].shape, dtype=float)
            for w, m in zip(weights, masks):
                votes += w * m
            ensemble = (votes > thr_value).astype(np.uint8) * 255

            orIm = imageio.imread(original_path)
            finalMaskPost, finalImPost = postProc(ensemble, orIm)

            out_mask.mkdir(parents=True, exist_ok=True)
            out_seg.mkdir(parents=True, exist_ok=True)
            imageio.imwrite(out_mask / f"{name}.png", finalMaskPost)
            imageio.imwrite(out_seg / f"{name}.png", finalImPost)
if __name__ == "__main__":
    # Example run. weight_vector needs one entry per method folder found
    # under root_dir, otherwise process_all raises ValueError.
    process_all(
        root_dir="aaaaa_results_test",
        class_list=["healthy", "mild", "moderate", "severe"],
        postProc=postProc,
        original_images_dir="datasets/dysplasia",
        output_dir="ENSEMBLE_dynamic_output",
        weight_vector=[1, 1, 1, 1, 1, 1],
        threshold_percent=0.40,
    )