@@ -290,7 +290,7 @@ def __init__(self, image=None, logger=None):
290290 "learning_rate" : 1e-5 ,
291291 "weight_decay" : 0.1 ,
292292 "n_epochs" : 100 ,
293- "model_name" : "cpsam " + d .strftime ("_%Y%m%d_%H%M%S" ),
293+ "model_name" : "cp4 " + d .strftime ("_%Y%m%d_%H%M%S" ),
294294 }
295295
296296 self .stitch_threshold = 0.
@@ -496,21 +496,32 @@ def make_buttons(self):
496496 )
497497 self .useGPU .setFont (self .medfont )
498498 self .check_gpu ()
499- self .segBoxG .addWidget (self .useGPU , widget_row , 0 , 1 , 3 )
500-
501- # compute segmentation with general models
502- self .net_text = ["run CPSAM" ]
503- nett = ["cellpose super-generalist model" ]
504-
505- self .StyleButtons = []
506- jj = 4
507- for j in range (len (self .net_text )):
508- self .StyleButtons .append (
509- guiparts .ModelButton (self , self .net_text [j ], self .net_text [j ]))
510- w = 5
511- self .segBoxG .addWidget (self .StyleButtons [- 1 ], widget_row , jj , 1 , w )
512- jj += w
513- self .StyleButtons [- 1 ].setToolTip (nett [j ])
499+ self .segBoxG .addWidget (self .useGPU , widget_row , 0 , 1 , 3 )
500+
501+ self .progress = QProgressBar (self )
502+ self .segBoxG .addWidget (self .progress , widget_row , 3 , 1 , 5 )
503+
504+ # compute segmentation with built-in models
505+ widget_row += 1
506+ self .ModelChooseB = QComboBox ()
507+ self .ModelChooseB .setFont (self .medfont )
508+ current_index = 0
509+ self .ModelChooseB .addItems (models .MODEL_NAMES )
510+ self .ModelChooseB .setFixedWidth (175 )
511+ self .ModelChooseB .setCurrentIndex (current_index )
512+ tipstr = 'built-in models'
513+ self .ModelChooseB .setToolTip (tipstr )
514+ self .ModelChooseB .activated .connect (lambda : self .model_choose (custom = False ))
515+ self .segBoxG .addWidget (self .ModelChooseB , widget_row , 0 , 1 , 8 )
516+
517+ # compute segmentation w/ custom model
518+ self .ModelButtonB = QPushButton (u"run" )
519+ self .ModelButtonB .setFont (self .medfont )
520+ self .ModelButtonB .setFixedWidth (35 )
521+ self .ModelButtonB .clicked .connect (
522+ lambda : self .compute_segmentation (custom = False ))
523+ self .segBoxG .addWidget (self .ModelButtonB , widget_row , 8 , 1 , 1 )
524+ self .ModelButtonB .setEnabled (False )
514525
515526 widget_row += 1
516527 self .ncells = guiparts .ObservableVariable (0 )
@@ -521,10 +532,7 @@ def make_buttons(self):
521532 lambda n : self .roi_count .setText (f'{ str (n )} ROIs' )
522533 )
523534
524- self .segBoxG .addWidget (self .roi_count , widget_row , 0 , 1 , 4 )
525-
526- self .progress = QProgressBar (self )
527- self .segBoxG .addWidget (self .progress , widget_row , 4 , 1 , 5 )
535+ self .segBoxG .addWidget (self .roi_count , widget_row , 3 , 1 , 4 )
528536
529537 widget_row += 1
530538
@@ -786,15 +794,13 @@ def check_gpu(self, torch=True):
786794
787795
788796 def model_choose (self , custom = False ):
789- index = self .ModelChooseC .currentIndex (
790- ) if custom else self .ModelChooseB .currentIndex ()
791- if index > 0 :
792- if custom :
793- model_name = self .ModelChooseC .currentText ()
794- else :
795- model_name = self .net_names [index - 1 ]
796- print (f"GUI_INFO: selected model { model_name } , loading now" )
797- self .initialize_model (model_name = model_name , custom = custom )
797+ if custom :
798+ model_name = self .ModelChooseC .currentText ()
799+ else :
800+ model_name = self .ModelChooseB .currentText ()
801+ print (f"GUI_INFO: selected model { model_name } " )
802+ # avoid double-loading model unless we need to?
803+ # self.initialize_model(model_name=model_name, custom=custom)
798804
799805 def toggle_scale (self ):
800806 if self .scale_on :
@@ -805,11 +811,10 @@ def toggle_scale(self):
805811 self .scale_on = True
806812
807813 def enable_buttons (self ):
814+ self .ModelButtonB .setEnabled (True )
808815 if len (self .model_strings ) > 0 :
809816 self .ModelButtonC .setEnabled (True )
810- for i in range (len (self .StyleButtons )):
811- self .StyleButtons [i ].setEnabled (True )
812-
817+
813818 for i in range (len (self .FilterButtons )):
814819 self .FilterButtons [i ].setEnabled (True )
815820 if self .load_3D :
@@ -1889,8 +1894,8 @@ def get_model_path(self, custom=False):
18891894 self .current_model_path = os .fspath (
18901895 models .MODEL_DIR .joinpath (self .current_model ))
18911896 else :
1892- self .current_model = "cpsam"
1893- self .current_model_path = models .model_path (self .current_model )
1897+ self .current_model = self . ModelChooseB . currentText ()
1898+ self .current_model_path = models .cache_model_path (self .current_model )
18941899
18951900 def initialize_model (self , model_name = None , custom = False ):
18961901 if model_name is None or custom :
@@ -1907,7 +1912,7 @@ def initialize_model(self, model_name=None, custom=False):
19071912 models .MODEL_DIR .joinpath (self .current_model ))
19081913
19091914 self .model = models .CellposeModel (gpu = self .useGPU .isChecked (),
1910- pretrained_model = self .current_model )
1915+ pretrained_model = self .current_model_path )
19111916
19121917 def add_model (self ):
19131918 io ._add_model (self )
@@ -1926,7 +1931,8 @@ def new_model(self):
19261931 image_names = self .get_files ()[0 ]
19271932 self .train_data , self .train_labels , self .train_files , restore , normalize_params = io ._get_train_set (
19281933 image_names )
1929- TW = guiparts .TrainWindow (self , models .MODEL_NAMES )
1934+ self .training_params ["model_index" ] = self .ModelChooseB .currentIndex ()
1935+ TW = guiparts .TrainWindow (self )
19301936 train = TW .exec_ ()
19311937 if train :
19321938 self .logger .info (
@@ -1944,7 +1950,7 @@ def train_model(self, restore=None, normalize_params=None):
19441950 self .current_model = model_type
19451951
19461952 self .model = models .CellposeModel (gpu = self .useGPU .isChecked (),
1947- model_type = model_type )
1953+ pretrained_model = model_type )
19481954 save_path = os .path .dirname (self .filename )
19491955
19501956 print ("GUI_INFO: name of new model: " + self .training_params ["model_name" ])
@@ -2048,9 +2054,7 @@ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
20482054 print (normalize_params )
20492055 try :
20502056 masks , flows = self .model .eval (
2051- data ,
2052- diameter = diameter ,
2053- cellprob_threshold = cellprob_threshold ,
2057+ data , diameter = diameter , cellprob_threshold = cellprob_threshold ,
20542058 flow_threshold = flow_threshold , do_3D = do_3D , niter = niter ,
20552059 normalize = normalize_params , stitch_threshold = stitch_threshold ,
20562060 anisotropy = anisotropy , flow3D_smooth = flow3D_smooth ,
0 commit comments