11import os
22
33import gradio as gr
4+ from PIL import Image
5+ import logging
6+ from zipfile import ZipFile
47
58from .inference import run_model
69from .utils import load_pred_volume_to_numpy
@@ -15,50 +18,47 @@ def __init__(
1518 cwd : str = "/home/user/app/" ,
1619 share : int = 1 ,
1720 ):
21+ self .file_output = None
22+ self .model_selector = None
23+ self .stripped_cb = None
24+ self .registered_cb = None
25+ self .run_btn = None
26+ self .slider = None
27+ self .download_file = None
28+
1829 # global states
1930 self .images = []
2031 self .pred_images = []
21-
22- # @TODO: This should be dynamically set based on chosen volume size
23- self .nb_slider_items = 512
32+ self .image_boxes = []
2433
2534 self .model_name = model_name
2635 self .cwd = cwd
2736 self .share = share
2837
29- self .class_name = "meningioma " # default
38+ self .class_name = "tumorcore " # default
3039 self .class_names = {
31- "meningioma " : "MRI_Meningioma " ,
32- "lower-grade-glioma " : "MRI_LGGlioma " ,
33- "metastasis " : "MRI_Metastasis " ,
34- "glioblastoma " : "MRI_GBM " ,
40+ "tumorcore " : "MRI_TumorCore " ,
41+ "NETC " : "MRI_Necrosis " ,
42+ "residual-tumor " : "MRI_TumorCE_Postop " ,
43+ "cavity " : "MRI_Cavity " ,
3544 "brain" : "MRI_Brain" ,
3645 }
3746
3847 self .result_names = {
39- "meningioma " : "Tumor" ,
40- "lower-grade-glioma " : "Tumor " ,
41- "metastasis " : "Tumor" ,
42- "glioblastoma " : "Tumor " ,
48+ "tumorcore " : "Tumor" ,
49+ "NETC " : "NETC " ,
50+ "residual-tumor " : "Tumor" ,
51+ "cavity " : "Cavity " ,
4352 "brain" : "Brain" ,
4453 }
4554
46- # define widgets not to be rendered immediately, but later on
47- self .slider = gr .Slider (
48- minimum = 1 ,
49- maximum = self .nb_slider_items ,
50- value = 1 ,
51- step = 1 ,
52- label = "Which 2D slice to show" ,
53- interactive = True ,
54- )
55-
5655 self .volume_renderer = gr .Model3D (
5756 clear_color = [0.0 , 0.0 , 0.0 , 0.0 ],
5857 label = "3D Model" ,
5958 visible = True ,
6059 elem_id = "model-3d" ,
61- ).style (height = 512 )
60+ height = 512 ,
61+ )
6262
6363 def set_class_name (self , value ):
6464 print ("Changed task to:" , value )
@@ -70,35 +70,97 @@ def combine_ct_and_seg(self, img, pred):
7070 def upload_file (self , file ):
7171 return file .name
7272
73- def process (self , mesh_file_name ):
73+ def process (self , mesh_file_name , stripped_inputs_status : bool = False ):
7474 path = mesh_file_name .name
7575 run_model (
7676 path ,
7777 model_path = os .path .join (self .cwd , "resources/models/" ),
7878 task = self .class_names [self .class_name ],
7979 name = self .result_names [self .class_name ],
80+ stripped_inputs_status = stripped_inputs_status ,
8081 )
8182 nifti_to_glb ("prediction.nii.gz" )
8283
8384 self .images = load_to_numpy (path )
84- # @TODO. Dynamic update of the slider does not seem to work like this
85- # self.nb_slider_items = len(self.images)
86- # self.slider.update(value=int(self.nb_slider_items/2), maximum=self.nb_slider_items)
8785
8886 self .pred_images = load_pred_volume_to_numpy ("./prediction.nii.gz" )
89- return "./prediction.obj"
87+ slider = gr .Slider (
88+ minimum = 0 ,
89+ maximum = len (self .images ) - 1 ,
90+ value = int (len (self .images ) / 2 ),
91+ step = 1 ,
92+ label = "Which 2D slice to show" ,
93+ interactive = True ,
94+ )
95+
96+ return "./prediction.obj" , slider
9097
9198 def get_img_pred_pair (self , k ):
92- k = int (k ) - 1
93- # @TODO. Will the duplicate the last slice to fill up, since slider not adjustable right now
94- if k >= len (self .images ):
95- k = len (self .images ) - 1
96- out = [gr .AnnotatedImage .update (visible = False )] * self .nb_slider_items
97- out [k ] = gr .AnnotatedImage .update (
98- self .combine_ct_and_seg (self .images [k ], self .pred_images [k ]),
99- visible = True ,
100- )
101- return out
99+ img = self .images [k ]
100+ img_pil = Image .fromarray (img )
101+ seg_list = []
102+ seg_list .append ((self .pred_images [k ], self .class_name ))
103+ return img_pil , seg_list
104+
105+ def setup_interface_inputs (self ):
106+ with gr .Row ():
107+ with gr .Column ():
108+ self .file_output = gr .File (file_count = "single" , elem_id = "upload" )
109+
110+ with gr .Column ():
111+ self .model_selector = gr .Dropdown (
112+ list (self .class_names .keys ()),
113+ label = "Segmentation task" ,
114+ info = "Select the segmentation model to run" ,
115+ multiselect = False ,
116+ # size="sm",
117+ )
118+
119+ with gr .Column ():
120+ with gr .Row ():
121+ self .stripped_cb = gr .Checkbox (label = "Stripped inputs" )
122+ self .registered_cb = gr .Checkbox (label = "Co-registered inputs" )
123+ with gr .Row ():
124+ self .run_btn = gr .Button ("Run segmentation" , scale = 1 )
125+
126+ def setup_interface_outputs (self ):
127+ with gr .Row ():
128+ with gr .Group ():
129+ with gr .Column ():
130+ t = gr .AnnotatedImage (
131+ visible = True ,
132+ elem_id = "model-2d" ,
133+ color_map = {self .class_name : "#ffae00" },
134+ height = 512 ,
135+ width = 512 ,
136+ )
137+
138+ self .slider = gr .Slider (
139+ minimum = 0 ,
140+ maximum = 1 ,
141+ value = 0 ,
142+ step = 1 ,
143+ label = "Which 2D slice to show" ,
144+ interactive = True ,
145+ )
146+
147+ self .slider .change (fn = self .get_img_pred_pair , inputs = self .slider , outputs = t )
148+
149+ with gr .Group ():
150+ self .volume_renderer .render ()
151+ self .download_btn = gr .DownloadButton (label = "Download results" , visible = False )
152+ self .download_file = gr .File (label = "Download Zip" , interactive = True , visible = False )
153+
154+ def package_results (self ):
155+ """Generates text files and zips them."""
156+ output_dir = "temp_output"
157+ os .makedirs (output_dir , exist_ok = True )
158+
159+ zip_filename = os .path .join (output_dir , "generated_files.zip" )
160+ with ZipFile (zip_filename , 'w' ) as zf :
161+ zf .write ("./prediction.nii.gz" )
162+
163+ return zip_filename
102164
103165 def run (self ):
104166 css = """
@@ -114,66 +176,29 @@ def run(self):
114176 }
115177 """
116178 with gr .Blocks (css = css ) as demo :
117- with gr .Row ():
118- file_output = gr .File (file_count = "single" , elem_id = "upload" )
119- file_output .upload (self .upload_file , file_output , file_output )
120-
121- model_selector = gr .Dropdown (
122- list (self .class_names .keys ()),
123- label = "Segmentation task" ,
124- info = "Select the preoperative segmentation model to run" ,
125- multiselect = False ,
126- size = "sm" ,
127- )
128- model_selector .input (
129- fn = lambda x : self .set_class_name (x ),
130- inputs = model_selector ,
131- outputs = None ,
132- )
133-
134- run_btn = gr .Button ("Run segmentation" ).style (
135- full_width = False , size = "lg"
136- )
137- run_btn .click (
138- fn = lambda x : self .process (x ),
139- inputs = file_output ,
140- outputs = self .volume_renderer ,
141- )
142-
179+ # Define the interface components first
180+ self .setup_interface_inputs ()
143181 with gr .Row ():
144182 gr .Examples (
145183 examples = [
146184 os .path .join (self .cwd , "t1gd.nii.gz" ),
147185 ],
148- inputs = file_output ,
149- outputs = file_output ,
186+ inputs = self . file_output ,
187+ outputs = self . file_output ,
150188 fn = self .upload_file ,
151189 cache_examples = True ,
152190 )
153-
154- with gr .Row ():
155- with gr .Box ():
156- with gr .Column ():
157- image_boxes = []
158- for i in range (self .nb_slider_items ):
159- visibility = True if i == 1 else False
160- t = gr .AnnotatedImage (
161- visible = visibility , elem_id = "model-2d"
162- ).style (
163- color_map = {self .class_name : "#ffae00" },
164- height = 512 ,
165- width = 512 ,
166- )
167- image_boxes .append (t )
168-
169- self .slider .input (
170- self .get_img_pred_pair , self .slider , image_boxes
171- )
172-
173- self .slider .render ()
174-
175- with gr .Box ():
176- self .volume_renderer .render ()
191+ self .setup_interface_outputs ()
192+
193+ # Define the signals/slots
194+ self .file_output .upload (self .upload_file , self .file_output , self .file_output )
195+ self .model_selector .input (fn = lambda x : self .set_class_name (x ), inputs = self .model_selector , outputs = None )
196+ self .run_btn .click (fn = self .process , inputs = [self .file_output , self .stripped_cb ],
197+ outputs = [self .volume_renderer , self .slider ]).then (fn = lambda :
198+ gr .DownloadButton (visible = True ), inputs = None , outputs = self .download_btn )
199+ self .download_btn .click (fn = self .package_results , inputs = [], outputs = self .download_file ).then (fn = lambda
200+ file_path : gr .File (label = "Download Zip" , visible = True , value = file_path ), inputs = self .download_file ,
201+ outputs = self .download_file )
177202
178203 # sharing app publicly -> share=True:
179204 # https://gradio.app/sharing-your-app/
0 commit comments