diff --git a/cellacdc/apps.py b/cellacdc/apps.py index f8fca90b..dc4cd369 100755 --- a/cellacdc/apps.py +++ b/cellacdc/apps.py @@ -94,18 +94,13 @@ from . import io from . import cca_functions from . import path +from . import fonts POSITIVE_FLOAT_REGEX = float_regex(allow_negative=False) TREEWIDGET_STYLESHEET = _palettes.TreeWidgetStyleSheet() LISTWIDGET_STYLESHEET = _palettes.ListWidgetStyleSheet() BACKGROUND_RGBA = _palettes.get_disabled_colors()['Button'] -font = QFont() -font.setPixelSize(12) -italicFont = QFont() -italicFont.setPixelSize(12) -italicFont.setItalic(True) - class ArgWidget: def __init__(self, name, type, widget, defaultVal, valueSetter, valueGetter, changeSig=None): self.name = name @@ -838,7 +833,7 @@ def __init__( self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def addCentroidsSection(self, row, layout, **kwargs): sectionWidgets = [] @@ -1441,7 +1436,7 @@ def __init__(self, parent=None): self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def restoreState(self, state): self.appearanceGroupbox.restoreState(state) @@ -1588,7 +1583,7 @@ def __init__( layout.addLayout(buttonsLayout) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) if defaultEntry: self.updateFilename(defaultEntry) @@ -1856,7 +1851,7 @@ def __init__(self, basename='', parent=None): mainLayout.addLayout(buttonsLayout) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def createThirdSegmToggled(self, checked): self.appendTextWidget.setDisabled(not checked) @@ -3026,9 +3021,7 @@ def __init__( self.imageViewer = None super().__init__(parent) self.setWindowTitle(title) - font = QFont() - font.setPixelSize(12) - self.setFont(font) + self.setFont(fonts.font) mainLayout = QVBoxLayout() entriesLayout = QGridLayout() @@ -3922,7 +3915,7 @@ def __init__(self, parent=None): layout.addLayout(buttonsLayout) layout.addStretch(1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def showInfo(self): msg = widgets.myMessageBox(wrapText=False) @@ -4114,7 +4107,7 @@ def __init__( layout.addLayout(buttonsLayout) layout.addStretch(1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def selectFeatures(self): features = measurements.get_btrack_features() @@ -4340,7 +4333,7 @@ def __init__(self, posData=None, parent=None): layout.addLayout(buttonsLayout) layout.addStretch(1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def methodChanged(self, method): if method == 'mothermachine': @@ -4565,7 +4558,7 @@ def __init__( self.loop = None self.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self, checked=False): self.cancel = False @@ -4741,7 +4734,7 @@ def __init__(self, fileName, folderPath, readPatternFunc=None, parent=None): self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def segmFolderpathSelected(self, path): self.segmFolderPathEntry.setText(path) @@ -4941,7 +4934,7 @@ def __init__(self, parent=None, isSegm3D=True): cancelButton.clicked.connect(self.close) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) self.configPars = self.loadLastSelection() @@ -5102,7 +5095,7 @@ def __init__(self, df: pd.DataFrame, parent=None): ) self.setLayout(self.mainLayout) - self.setFont(font) + self.setFont(fonts.font) def saveSelection(self): saved_selections = io.get_saved_moth_bud_tot_selections() @@ -5507,7 +5500,7 @@ def __init__(self, df, parent=None): self.mainLayout.addLayout(buttonsLayout) self.setLayout(self.mainLayout) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self): self.cancel = False @@ -5577,50 +5570,66 @@ def ok_cb(self): self.model_name = self.listBox.currentItem().text() self.close() - class QDialogSelectModel(QDialog): def __init__( - self, parent=None, addSkipSegmButton=False, customFirst='' + self, parent=None, addSkipSegmButton=False, customFirst='', + allowMultiSelection=False, lastSelection=None, + addSelectLastSelectionButton=False, + addSelectLastRecipeButton=False, + custom_title=None, + info_label='', ): self.cancel = True + self.loadLastRecipe = False super().__init__(parent) self.setWindowTitle('Select model') + self.info_label = info_label - mainLayout = QVBoxLayout() - topLayout = QVBoxLayout() - bottomLayout = QHBoxLayout() + self.allowMultiSelection = allowMultiSelection + self.lastSelection = [] + for m in (lastSelection or []): + if not isinstance(m, str): + continue + if m == 'thresholding': + m = 'Automatic thresholding' + self.lastSelection.append(m) + mainLayout = QVBoxLayout() self.mainLayout = mainLayout + title = custom_title or 'Select model to use for segmentation: ' + + titleContainer = QWidget(self) + titleLayout = QGridLayout(titleContainer) + titleLayout.setContentsMargins(0, 0, 0, 0) + titleLayout.setSpacing(0) label = QLabel(html_utils.paragraph( - 'Select model to use for segmentation: ' + title )) - # padding: top, left, bottom, right label.setStyleSheet("padding:0px 0px 3px 0px;") - topLayout.addWidget(label, alignment=Qt.AlignCenter) - - listBox = widgets.listWidget() - models = myutils.get_list_of_models() + titleLayout.addWidget(label, 0, 0, Qt.AlignCenter) + if info_label: + moreInfoButton = widgets.infoPushButton() + moreInfoButton.clicked.connect(self.showInfoLabel) + moreInfoButton.setSizePolicy( + QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed + ) + titleLayout.addWidget(moreInfoButton, 0, 0, Qt.AlignTop | Qt.AlignRight) + mainLayout.addWidget(titleContainer) - if customFirst: - try: - idx = models.index(customFirst) - models.insert(0, models.pop(idx)) - except ValueError: - print(f'Warning: {customFirst} not found in models list.') - pass + self.modelSelector = widgets.ModelSelectionWidget( + parent=self, + customFirst=customFirst, + allowMultiSelection=allowMultiSelection, + ) + # Convenience aliases kept for backward compatibility + self.listBox = self.modelSelector.listBox + mainLayout.addWidget(self.modelSelector) - listBox.setFont(font) - listBox.addItems(models) - addCustomModelItem = QListWidgetItem('Add custom model...') - addCustomModelItem.setFont(italicFont) - listBox.addItem(addCustomModelItem) - listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) - listBox.setCurrentRow(0) - self.listBox = listBox - listBox.itemDoubleClicked.connect(self.ok_cb) - topLayout.addWidget(listBox) + if not allowMultiSelection: + self.listBox.itemDoubleClicked.connect(self.ok_cb) + bottomLayout = QHBoxLayout() cancelButton = widgets.cancelPushButton('Cancel') okButton = widgets.okPushButton(' Ok ') okButton.setShortcut(Qt.Key_Enter) @@ -5632,53 +5641,178 @@ def __init__( skipSegmButton = widgets.SkipPushButton('Skip segmentation') bottomLayout.addWidget(skipSegmButton) skipSegmButton.clicked.connect(self.skipSegm) + if addSelectLastSelectionButton and allowMultiSelection: + selectLastSelButton = widgets.reloadPushButton('Load last selection...') + selectLastSelButton.clicked.connect(self.selectLastSelection) + selectLastSelButton.setEnabled(bool(self.lastSelection)) + bottomLayout.addWidget(selectLastSelButton) + if addSelectLastRecipeButton and allowMultiSelection: + selectLastRecipeButton = widgets.reloadPushButton('Load last recipe...') + selectLastRecipeButton.clicked.connect(self.selectLastRecipe) + selectLastRecipeButton.setEnabled(bool(self.lastSelection)) + bottomLayout.addWidget(selectLastRecipeButton) + if allowMultiSelection: + addCustomModelButton = widgets.addPushButton('Add custom model...') + addCustomModelButton.clicked.connect(self.addCustomModel) + bottomLayout.addWidget(addCustomModelButton) bottomLayout.addWidget(okButton) bottomLayout.setContentsMargins(0, 10, 0, 0) - mainLayout.addLayout(topLayout) mainLayout.addLayout(bottomLayout) self.setLayout(mainLayout) - # Connect events okButton.clicked.connect(self.ok_cb) cancelButton.clicked.connect(self.cancel_cb) - self.setStyleSheet(LISTWIDGET_STYLESHEET) - + + @property + def selectionSequence(self): + return self.modelSelector.selectionSequence + + @property + def modelItemsMap(self): + return self.modelSelector.modelItemsMap + def skipSegm(self): self.cancel = False self.selectedModel = 'skip_segmentation' self.close() - + + def selectLastSelection(self): + if not self.lastSelection: + return + self.modelSelector.setSelectionFromList(self.lastSelection) + + def selectLastRecipe(self): + if not self.lastSelection: + return + self.selectLastSelection() + self.cancel = False + self.loadLastRecipe = True + self.selectedModel = self.lastSelection.copy() + self.close() + + def _runAddCustomModelWorkflow(self): + modelFilePath = addCustomModelMessages(self) + if modelFilePath is None: + return None + + myutils.store_custom_model_path(modelFilePath) + modelName = os.path.basename(os.path.dirname(modelFilePath)) + self.modelSelector.registerCustomModel(modelName) + return modelName + + def addCustomModel(self): + modelName = self._runAddCustomModelWorkflow() + if modelName is None: + return + + if self.allowMultiSelection: + self.modelSelector.addModelSelection(modelName) + else: + item = QListWidgetItem(modelName) + self.listBox.addItem(item) + self.listBox.setCurrentItem(item) + + def showInfoLabel(self): + if not self.info_label: + return + msg = widgets.myMessageBox(showCentered=False, wrapText=False) + txt = html_utils.paragraph(self.info_label) + msg.information(self, 'More info', txt) + def keyPressEvent(self, event: QKeyEvent) -> None: if event.key() == Qt.Key_Escape: event.ignore() return - super().keyPressEvent(event) - def ok_cb(self, event): + def askSelectedModelsOrder(self, selected_models): + dialog = QDialog(self) + dialog.setWindowTitle('Order selected models') + dialog.setWindowFlags(Qt.Window | Qt.WindowStaysOnTopHint) + + layout = QVBoxLayout(dialog) + infoTxt = html_utils.paragraph( + 'Drag and drop to change the order of the selected models.
' + 'The top model will run first.' + ) + layout.addWidget(QLabel(infoTxt)) + + modelOrderView = widgets.ReorderableListView( + selected_models, parent=dialog, isSingleSelection=True + ) + layout.addWidget(modelOrderView) + + buttonsLayout = QHBoxLayout() + cancelButton = widgets.cancelPushButton('Cancel') + okButton = widgets.okPushButton('Ok') + buttonsLayout.addStretch(1) + buttonsLayout.addWidget(cancelButton) + buttonsLayout.addSpacing(20) + buttonsLayout.addWidget(okButton) + layout.addLayout(buttonsLayout) + + cancelButton.clicked.connect(dialog.reject) + okButton.clicked.connect(dialog.accept) + + if dialog.exec_() != QDialog.Accepted: + return None + + return modelOrderView.items() + + def ok_cb(self, event=None): self.clickedButton = self.sender() - self.cancel = False - item = self.listBox.currentItem() + + if self.allowMultiSelection: + if not self.selectionSequence: + return + + selected_models = list(self.selectionSequence) + if len(selected_models) > 1: + ordered_models = self.askSelectedModelsOrder(selected_models) + if ordered_models is None: + return + selected_models = ordered_models + + self.selectedModel = selected_models + self.cancel = False + self.close() + return + + selected_items = self.listBox.selectedItems() + if not selected_items: + return + + selected_models = [item.text() for item in selected_items] + if len(selected_models) > 1: + ordered_models = self.askSelectedModelsOrder(selected_models) + if ordered_models is None: + return + self.selectedModel = ordered_models + self.cancel = False + self.close() + return + + item = selected_items[0] model = item.text() if model == 'Add custom model...': - modelFilePath = addCustomModelMessages(self) - if modelFilePath is None: + modelName = self._runAddCustomModelWorkflow() + if modelName is None: return - myutils.store_custom_model_path(modelFilePath) - modelName = os.path.basename(os.path.dirname(modelFilePath)) item = QListWidgetItem(modelName) self.listBox.addItem(item) self.listBox.setCurrentItem(item) elif model == 'Automatic thresholding': - self.selectedModel = 'thresholding' + self.selectedModel = model + self.cancel = False self.close() else: self.selectedModel = model + self.cancel = False self.close() - def cancel_cb(self, event): + def cancel_cb(self, event=None): self.cancel = True self.selectedModel = None self.close() @@ -5725,7 +5859,7 @@ def __init__(self, text, parent=None): mainLayout.addLayout(buttonsLayout) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) class startStopFramesDialog(QBaseDialog): def __init__( @@ -5760,7 +5894,7 @@ def __init__( okButton.clicked.connect(self.ok_cb) cancelButton.clicked.connect(self.close) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self): if self.selectFramesGroupbox.warningLabel.text(): @@ -6186,7 +6320,8 @@ def __init__( self.addAdditionalValues(additionalValues) self.setLayout(mainLayout) - self.setFont(font) + if font is not None: + self.setFont(font) # self.setModal(True) def showWhySizeTisGrayed(self): @@ -6723,9 +6858,7 @@ def __init__(self, mainWindow): seeHereLabel.setTextFormat(Qt.RichText) seeHereLabel.setTextInteractionFlags(Qt.TextBrowserInteraction) seeHereLabel.setOpenExternalLinks(True) - font = QFont() - font.setPixelSize(12) - seeHereLabel.setFont(font) + seeHereLabel.setFont(fonts.font) seeHereLabel.setStyleSheet("padding:12px 0px 0px 0px;") paramsLayout.addWidget(seeHereLabel, row, 0, 1, 2) @@ -7112,7 +7245,7 @@ def __init__( layout.addLayout(buttonsLayout, 2, 1) self.setLayout(layout) - self.setFont(font) + self.setFont(fonts.font) def copyErrorMessage(self): cb = QApplication.clipboard() @@ -8029,7 +8162,7 @@ def addAlphaScrollbar(self, channelName, imageItem, alphaScrollBar=None): if alphaScrollBar is None: alphaScrollBar = QScrollBar(Qt.Horizontal) label = QLabel(f'Alpha {channelName}') - label.setFont(font) + label.setFont(fonts.font) label.hide() alphaScrollBar.imageItem = imageItem alphaScrollBar.label = label @@ -8584,7 +8717,7 @@ def __init__(self, expPaths: dict, infoPaths: dict=None, parent=None): QAbstractItemView.SelectionMode.ExtendedSelection ) self.treeWidget.setHeaderHidden(True) - self.treeWidget.setFont(font) + self.treeWidget.setFont(fonts.font) for exp_path, positions in expPaths.items(): pathLevels = exp_path.split(os.sep) posFoldersInfo = None @@ -9521,7 +9654,7 @@ def __init__( entryWidget.setText(defaultTxt) if not self.allowText: entryWidget.textChanged[str].connect(self.onTextChanged) - entryWidget.setFont(font) + entryWidget.setFont(fonts.font) entryWidget.setAlignment(Qt.AlignCenter) self.entryWidget = entryWidget @@ -9529,7 +9662,7 @@ def __init__( if allowedValues is not None: notValidLabel = QLabel() notValidLabel.setStyleSheet('color: red') - notValidLabel.setFont(font) + notValidLabel.setFont(fonts.font) notValidLabel.setAlignment(Qt.AlignCenter) self.notValidLabel = notValidLabel @@ -10055,7 +10188,7 @@ def __init__( listBox.addItems(items) listBox.setSelectionMode(QAbstractItemView.SelectionMode.ExtendedSelection) listBox.setCurrentRow(0) - listBox.setFont(font) + listBox.setFont(fonts.font) topLayout.addWidget(listBox) listBox.hide() self.ListBox = listBox @@ -10077,7 +10210,7 @@ def __init__( if showInFileManagerPath is not None: showInFileManagerButton.clicked.connect(self.showInFileManager) - self.setFont(font) + self.setFont(fonts.font) def setSelectedItems(self, selectedItemsText): if self.multiPosButton.isChecked(): @@ -10975,9 +11108,7 @@ def __init__(self, filename, SizeZ, filenamesWithInfo, parent=None): self.setLayout(mainLayout) - font = QFont() - font.setPixelSize(12) - self.setFont(font) + self.setFont(fonts.font) # self.setModal(True) @@ -11838,7 +11969,7 @@ def __init__( printl(traceback.format_exc()) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) # self.setModal(True) def warningNoSegmRecipes(self): @@ -13045,7 +13176,7 @@ def __init__( metricsTreeWidget = QTreeWidget() metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) + metricsTreeWidget.setFont(fonts.font) self.metricsTreeWidget = metricsTreeWidget for chName in allChNames: @@ -13236,7 +13367,7 @@ def __init__( testButton.clicked.connect(self.test_cb) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.setStyleSheet(TREEWIDGET_STYLESHEET) @@ -13497,7 +13628,7 @@ def __init__(self, posDatas, parent=None): _spinBox = QSpinBox() _spinBox.setMaximum(214748364) _spinBox.setAlignment(Qt.AlignCenter) - _spinBox.setFont(font) + _spinBox.setFont(fonts.font) if posData.acdc_df is not None: _val = posData.acdc_df.index.get_level_values(0).max()+1 else: @@ -13616,7 +13747,7 @@ def __init__(self, acdcDfs, allChNames, parent=None, debug=False): for i, (acdc_df_endname, acdc_df) in enumerate(acdcDfs.items()): metricsTreeWidget = QTreeWidget() metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) + metricsTreeWidget.setFont(fonts.font) classified_metrics = measurements.classify_acdc_df_colnames( acdc_df, allChNames @@ -13804,7 +13935,7 @@ def __init__(self, acdcDfs, allChNames, parent=None, debug=False): # self.newColNameLineEdit.editingFinished.connect(self.equationChanged) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.setStyleSheet(TREEWIDGET_STYLESHEET) @@ -13992,7 +14123,7 @@ def __init__( row += 1 self.equationsList = widgets.TreeWidget() - self.equationsList.setFont(font) + self.equationsList.setFont(fonts.font) self.equationsList.setHeaderLabels(['Metric', 'Expression']) self.equationsList.setSelectionMode( QAbstractItemView.SelectionMode.ExtendedSelection) @@ -14301,7 +14432,7 @@ def __init__( mainLayout.addSpacing(20) mainLayout.addLayout(buttonsLayout) - self.setFont(font) + self.setFont(fonts.font) self.setLayout(mainLayout) def checkDuplicateShortcuts(self, text): @@ -14428,7 +14559,7 @@ def __init__(self, posData, parent=None): self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def ok_cb(self): self.cancel = False @@ -14644,7 +14775,7 @@ def __init__( self.setLayout(self._layout) - # self.setFont(font) + # self.setFont(fonts.font) self.addButton.clicked.connect(self.addFeatureField) @@ -14927,7 +15058,7 @@ def __init__( mainLayout.addStretch() self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.unitCombobox.currentTextChanged.connect(self.updateLengthUnit) self.colorButton.clicked.disconnect() @@ -15046,7 +15177,7 @@ def __init__( self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) def _warnNonUniqueCategories(self, category_1, category_2): txt = html_utils.paragraph(f""" @@ -15107,7 +15238,7 @@ def __init__( metricsTreeWidget = QTreeWidget() metricsTreeWidget.setHeaderHidden(True) - metricsTreeWidget.setFont(font) + metricsTreeWidget.setFont(fonts.font) self.metricsTreeWidget = metricsTreeWidget for groupName, features in features_groups.items(): @@ -15152,7 +15283,7 @@ def __init__( metricsTreeWidget.itemDoubleClicked.connect(self.addFeatureName) self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.setStyleSheet(TREEWIDGET_STYLESHEET) @@ -15438,7 +15569,7 @@ def __init__(self, parent=None, title='Input'): self.buttonsLayout = buttonsLayout - self.setFont(font) + self.setFont(fonts.font) self.setLayout(self.mainLayout) def askText(self, prompt, infoText='', allowEmpty=False): @@ -15908,7 +16039,7 @@ def __init__(self, parent=None, **properties): mainLayout.addStretch() self.setLayout(mainLayout) - self.setFont(font) + self.setFont(fonts.font) self.colorButton.clicked.disconnect() self.colorButton.clicked.connect(self.selectColor) @@ -19312,7 +19443,7 @@ def __init__( self.setAcceptDrops(True) - self.setFont(font) + self.setFont(fonts.font) def dragEnterEvent(self, event): event.acceptProposedAction() @@ -19352,12 +19483,59 @@ def expFolderToPosFoldernamesMapper(self): return expPathsPosFoldernamesMapper def ok_cb(self): - self.cancel = False + #verify all selected folders have Images folder: + faultyFolders = [] + for path, selected_pos in self.expFolderToPosFoldernamesMapper().items(): + if selected_pos == ['']: + images_path = myutils.get_images_folderpath(path) + if images_path is None or not os.path.exists(images_path): + faultyFolders.append(path) + + else: + for pos in selected_pos: + pos_path = os.path.join(path, pos) + images_path = myutils.get_images_folderpath(pos_path) + if images_path is None or not os.path.exists(images_path): + faultyFolders.append(pos_path) + + if faultyFolders: + self.warnNoAllValid(faultyFolders) + return + self.paths = self.pathsList() self.selectedExpFolderToPosFoldernamesMapper = ( self.expFolderToPosFoldernamesMapper() ) + if not self.selectedExpFolderToPosFoldernamesMapper: + self.warnEmptySelection() + return + self.cancel = False + self.close() + + def warnNoAllValid(self, faultyFolders=None): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(f""" + Some of the selected folders (see below) do not contain an Images folder.

+ Please, make sure to select Position folders, the Images folder inside Position folders, or any folder containing Position folders as sub-directories.

+ Thank you for your patience!

+ Selected folders: + + """) + msg.warning( + self, 'Some folders are not valid', txt + ) + + def warnEmptySelection(self): + msg = widgets.myMessageBox(wrapText=False) + txt = html_utils.paragraph(""" + No folder was selected.

+ """) + msg.warning( + self, 'No folder selected', txt + ) def warnNoValidPathsFound(self, selected_path): msg = widgets.myMessageBox(wrapText=False) @@ -19426,11 +19604,11 @@ def addFolderPath(self, selected_path): myutils.addToRecentPaths(selected_path) folder_type = myutils.determine_folder_type(selected_path) - is_pos_folder, is_images_folder, folder_path = folder_type + is_pos_folder, is_images_folder, folder_path = folder_type if is_pos_folder: paths = [selected_path] elif is_images_folder: - paths = [os.path.dirname(selected_path)] + paths = [os.path.dirname(selected_path) if selected_path.endswith('Images') else selected_path] elif self.scanTree: print(f'Scanning selected folder "{selected_path}"...') exp_paths = path.get_posfolderpaths_walk(selected_path) diff --git a/cellacdc/core.py b/cellacdc/core.py index 110de236..d34733dc 100755 --- a/cellacdc/core.py +++ b/cellacdc/core.py @@ -3285,7 +3285,7 @@ def check_file_time_proximity(file1, file2, max_seconds=300, logger_func=print): logger_func(f'Warning: The files "{file1}" and "{file2}" were not saved within {max_seconds} seconds of each other.') return False -def verify_acdc_df_segm(posData: load.loadData, logger_func=print): +def verify_acdc_df_segm(posData: 'load.loadData', logger_func=print): if posData.segmMetadata is None: return None segm_info = posData.segmMetadata[os.path.basename(posData.segm_npz_path)] @@ -3304,7 +3304,7 @@ def verify_acdc_df_segm(posData: load.loadData, logger_func=print): return csv_filepath -def verify_add_data_segm_proximity(posData: load.loadData, logger_func=print): +def verify_add_data_segm_proximity(posData: 'load.loadData', logger_func=print): segm_path = posData.segm_npz_path segm_filename = os.path.basename(segm_path).replace('.npz', '') add_data_folder = os.path.join(posData.images_path, segm_filename) @@ -3336,7 +3336,7 @@ def verify_add_data_segm_proximity(posData: load.loadData, logger_func=print): # Total time spend optimising here # >5 hrs # please update this if you try to optimize again -def count_objects_and_init_rps(posData: load.loadData, logger_func=print): +def count_objects_and_init_rps(posData: 'load.loadData', logger_func=print): allIDs = set() segm_data = posData.segm_data diff --git a/cellacdc/dataPrep.py b/cellacdc/dataPrep.py index 10b55844..9856e7f4 100755 --- a/cellacdc/dataPrep.py +++ b/cellacdc/dataPrep.py @@ -44,6 +44,7 @@ from . import urls from . import io from .help import about +from . import fonts if os.name == 'nt': try: @@ -373,7 +374,7 @@ def gui_createToolBars(self): navigateToolbar.addAction(self.interpAction) self.ROIshapeComboBox = QComboBox() - self.ROIshapeComboBox.setFont(apps.font) + self.ROIshapeComboBox.setFont(fonts.font) self.ROIshapeComboBox.SizeAdjustPolicy(QComboBox.SizeAdjustPolicy.AdjustToContents) self.ROIshapeComboBox.addItems([' 256x256 ']) ROIshapeLabel = QLabel(html_utils.paragraph( diff --git a/cellacdc/fonts.py b/cellacdc/fonts.py new file mode 100644 index 00000000..31219da5 --- /dev/null +++ b/cellacdc/fonts.py @@ -0,0 +1,14 @@ +from . import GUI_INSTALLED + +if GUI_INSTALLED: + from qtpy.QtGui import QFont + + font = QFont() + font.setPixelSize(12) + italicFont = QFont() + italicFont.setPixelSize(12) + italicFont.setItalic(True) + +else: + font = None + italicFont = None \ No newline at end of file diff --git a/cellacdc/gui.py b/cellacdc/gui.py index 443c8fee..72017406 100755 --- a/cellacdc/gui.py +++ b/cellacdc/gui.py @@ -543,7 +543,7 @@ def initGlobalAttr(self): ] self.lin_tree_df_colnames = self.lin_tree_df_int_cols + self.lin_tree_df_bool_col + self.lin_tree_col_checks - self.SegForLostIDsSettings = {} + self.SegForLostIDsSettings = {} def setWindowIcon(self, icon=None): if icon is None: @@ -8452,103 +8452,325 @@ def gui_addCreatedAxesItems(self): self.ax1.exportMaskImageItem = self.exportMaskImageItem def SegForLostIDsSetSettings(self): + posData = self.data[self.pos_i] + displayed_input_label = 'Displayed image' + + recipe_json_path = os.path.join( + settings_folderpath, 'segmentation_for_lostIDs_recipe.json' + ) try: - prev_model = str(self.df_settings.at['SegForLostIDsModel', 'value']) + prev_models = [ + model.strip() for model in str( + self.df_settings.at['SegForLostIDsModel', 'value'] + ).split(',') if model.strip() + ] except KeyError: - prev_model = None - win = apps.QDialogSelectModel(parent=self, customFirst=prev_model) + prev_models = [] + + has_last_recipe = bool(prev_models) and os.path.exists(recipe_json_path) + seg_for_lost_ids_info = ( + 'Segmentation for lost IDs settings

' + 'Use this dialog to define the segmentation workflow used for ' + 'resegmenting local neighborhood lost IDs. Other already segmented cells are filled ' + 'with background, which makes even dimm cells seem bright after ' + 'rescaling before resegmentation. This is especially usefull for ' + 'cells which have varying intensities over time, like FUCCI cells.

' + 'How model selection works
' + '- You can select one model or multiple models.
' + '- In multi-selection mode you can include the same model multiple ' + 'times (for example, model A, then model B, then model A again).
' + '- After confirming, you can reorder the selected models. The order ' + 'is the execution order. ' + '- You then will be asked to set model parameters in the order selected.

' + ' - Pay special attention to the additional "Settings for local ' + 'segmentation" section, here you can for example select any image as input.

' + 'Load last selection...
' + 'Restores only the list of selected model names (the recipe order ' + 'selection), then lets you continue configuring parameters.

' + 'Load last recipe...
' + 'Loads the complete saved recipe from disk, including model order and ' + 'all model-specific settings (when available).

' + 'Add custom model...
' + 'Lets you register an additional local custom model and include it in ' + 'the sequence.

' + 'Tip: if you want to run the same model twice with different ' + 'parameters, add it twice and configure each step independently.' + ) + win = apps.QDialogSelectModel( + parent=self, + allowMultiSelection=True, + lastSelection=prev_models, + addSelectLastSelectionButton=bool(prev_models), + addSelectLastRecipeButton=has_last_recipe, + custom_title='Select model(s) for segmentation of lost IDs', + info_label=seg_for_lost_ids_info, + ) win.exec_() if win.cancel: self.logger.info('Seg for lost IDs cancelled.') return - base_model_name = win.selectedModel - if base_model_name: - self.df_settings.at['SegForLostIDsModel', 'value'] = base_model_name + if getattr(win, 'loadLastRecipe', False): + self.logger.info('Loading last segmentation recipe for lost IDs...') + try: + with open(recipe_json_path, 'r') as f: + recipe_data = json.load(f) + model_settings = [] + for entry in recipe_data['models']: + model_settings.append({ + 'win': None, + 'init_kwargs_new': entry['init_kwargs_new'], + 'args_new': entry['args_new'], + 'base_model_name': entry['base_model_name'], + 'init_kwargs': entry.get('init_kwargs', {}), + 'model_kwargs': entry.get('model_kwargs', {}), + 'preproc_recipe': entry.get('preproc_recipe', None), + 'applyPostProcessing': entry.get('applyPostProcessing', False), + 'standardPostProcessKwargs': entry.get('standardPostProcessKwargs', {}), + 'customPostProcessFeatures': entry.get('customPostProcessFeatures', None), + 'customPostProcessGroupedFeatures': entry.get('customPostProcessGroupedFeatures', None), + }) + self.SegForLostIDsSettings = {'models_settings': model_settings} + # Restore model names in settings + restored_models = [ + 'Automatic thresholding' + if m['base_model_name'] == 'thresholding' + else m['base_model_name'] + for m in model_settings + ] + self.df_settings.at['SegForLostIDsModel', 'value'] = ( + ', '.join(restored_models) + ) + self.df_settings.to_csv(self.settings_csv_path) + self.logger.info('Last segmentation recipe loaded successfully.') + except Exception as e: + self.logger.error(f'Failed to load last recipe: {e}') + return + + selected_models = win.selectedModel + if isinstance(selected_models, str): + selected_models = [selected_models] + + if not selected_models: + self.logger.info('Seg for lost IDs cancelled.') + return + + if selected_models: + self.df_settings.at['SegForLostIDsModel', 'value'] = ( + ', '.join(selected_models) + ) self.df_settings.to_csv(self.settings_csv_path) - model_name = 'local_seg' + all_extra_params = [ + 'image_channel_name', + 'overlap_threshold', + 'padding', + 'size_perc_diff', + 'distance_filler_growth', + 'allow_only_tracked_cells', + ] + extra_types = { + 'overlap_threshold': float, + 'padding': float, + 'size_perc_diff': float, + 'distance_filler_growth': float, + 'allow_only_tracked_cells': bool, + 'image_channel_name': str, + } + extra_defaults = { + 'overlap_threshold': 0.5, + 'padding': 0.8, + 'size_perc_diff': 0.3, + 'distance_filler_growth': 1., + 'allow_only_tracked_cells': False, + 'image_channel_name': displayed_input_label, + } + extra_desc = { + 'overlap_threshold': ( + 'Overlap threshold with other already segemented cells ' + 'over which newly segmented cells are discarded' + ), + 'padding': ( + 'Padding of the box used for new segmentation around the ' + 'segmentation from the previous frame' + ), + 'size_perc_diff': ( + 'Relative size difference acceptable compared to previous ' + 'frames' + ), + 'distance_filler_growth': ( + 'Cells which are already segmented are filled with random ' + 'noise sampled from background to ensure that they do not ' + 'get segmented again. This parameter controls the additional ' + 'padding around the already segmented cells.' + ), + 'allow_only_tracked_cells': ( + 'If no new cell IDs should be permitted ' + '(based on real time tracking)' + ), + 'image_channel_name': ( + 'Image channel used as model input. ' + 'Select "Displayed image" to use exactly what is currently ' + 'shown in the viewer, or select a specific fluorescence ' + 'channel.' + ), + } - idx = self.modelNames.index(model_name) - acdcSegment = self.acdcSegment_li[idx] + model_settings = [] + remembered_extra_args = {} + for model_idx, selected_model_name in enumerate(selected_models): + model_name = selected_model_name + if model_name == 'Automatic thresholding': + model_name = 'thresholding' + try: + if selected_model_name in self.modelNames: + idx = self.modelNames.index(selected_model_name) + acdcSegment = self.acdcSegment_li[idx] + if acdcSegment is None: + self.logger.info(f'Importing {model_name}...') + acdcSegment = myutils.import_segment_module(model_name) + self.acdcSegment_li[idx] = acdcSegment + else: + self.logger.info(f'Importing {model_name}...') + acdcSegment = myutils.import_segment_module(model_name) + except (ImportError, KeyError) as e: + self.logger.error(f'Error importing {model_name}: {e}') + return - try: - if acdcSegment is None or base_model_name != self.local_seg_base_model_name: - self.logger.info(f'Importing {base_model_name}...') - acdcSegment = myutils.import_segment_module(base_model_name) - self.acdcSegment_li[idx] = acdcSegment - self.local_seg_base_model_name = base_model_name - except (IndexError, ImportError, KeyError) as e: - self.logger.error(f'Error importing {base_model_name}: {e}') - return - - extra_params = ['overlap_threshold', - 'padding', - 'size_perc_diff', - 'distance_filler_growth', - 'max_iterations', - 'allow_only_tracked_cells'] - - extra_types = [float, float, float, float, int, bool] - - extra_defaults = [0.5, 0.8, 0.3, 1., 2, False] - - extra_desc = ['Overlap threshold with other already segemented cells over which newly segmented cells are discarded', - 'Padding of the box used for new segmentation around the segmentation from the previous frame', - 'Relative size difference acceptable compared to previous frames', - """Cells which are already segmented are filled with random noise sampled from background - to ensure that they don't get segmented again. - This parameter controls the additional padding around the already segmented cells.""", - """The algorithm will try and segment the maximum amount - of cells in the image by running the model several - times and filling new found cells with background noise. - How many of these iterations should be run?""", - "If no new cell IDs should be permitted (based on real time tracking)"] - - extra_ArgSpec = [] - for i, param in enumerate(extra_params): - param = ArgSpec(name=param, - default=extra_defaults[i], - type=extra_types[i], - desc=extra_desc[i], - docstring='') - - extra_ArgSpec.append(param) + extra_params = all_extra_params - init_params, segment_params = myutils.getModelArgSpec(acdcSegment) - segment_params = [arg for arg in segment_params if arg[0] != 'diameter'] - - extraParamsTitle = 'Settings for local segmentation' - win = self.initSegmModelParams( - base_model_name, acdcSegment, init_params, segment_params, - extraParams=extra_ArgSpec, extraParamsTitle=extraParamsTitle, - initLastParams=True, ini_filename='segmentation_for_lostIDs.ini', - ) + available_fluo_channels = [ + ch for ch in posData.chNames if ch != self.user_ch_name + ] + channel_options = [displayed_input_label, *available_fluo_channels] - if win is None: - self.logger.info('Segmentation for lost IDs cancelled.') - return + class _SegForLostIDsInputChannelType: + values = channel_options - init_kwargs_new = {} - args_new = {} - for key, val in win.init_kwargs.items(): - if key in extra_params: - args_new[key] = val - else: - init_kwargs_new[key] = val + extra_types['image_channel_name'] = _SegForLostIDsInputChannelType + + extra_ArgSpec = [] + for param in extra_params: + param_arg = ArgSpec( + name=param, + default=extra_defaults[param], + type=extra_types[param], + desc=extra_desc[param], + docstring='' + ) + extra_ArgSpec.append(param_arg) - for key, val in win.extra_kwargs.items(): - if key in extra_params: - args_new[key] = val + init_params, segment_params = myutils.getModelArgSpec(acdcSegment) + segment_params = [ + arg for arg in segment_params if arg[0] != 'diameter' + ] + + initLastParams = True + if model_name == 'thresholding': + win_thresh = apps.QDialogAutomaticThresholding( + parent=self, isSegm3D=self.isSegm3D + ) + win_thresh.exec_() + if win_thresh.cancel: + self.logger.info('Segmentation for lost IDs cancelled.') + return + self.model_kwargs = win_thresh.segment_kwargs + thresh_method = self.model_kwargs['threshold_method'] + gauss_sigma = self.model_kwargs['gauss_sigma'] + segment_params = myutils.insertModelArgSpec( + segment_params, 'threshold_method', thresh_method + ) + segment_params = myutils.insertModelArgSpec( + segment_params, 'gauss_sigma', gauss_sigma + ) + initLastParams = False + + extraParamsTitle = ( + f'Settings for local segmentation ' + f'({model_idx + 1}/{len(selected_models)})' + ) + win = self.initSegmModelParams( + model_name, acdcSegment, init_params, segment_params, + extraParams=extra_ArgSpec, + extraParamsTitle=extraParamsTitle, + initLastParams=initLastParams, + ini_filename='segmentation_for_lostIDs.ini', + ) + + if win is None: + self.logger.info('Segmentation for lost IDs cancelled.') + return + + init_kwargs_new = {} + args_new = {} + for key, val in win.init_kwargs.items(): + if key in extra_params: + args_new[key] = val + else: + init_kwargs_new[key] = val + + for key, val in win.extra_kwargs.items(): + if key in extra_params: + if key == 'image_channel_name': + init_kwargs_new[key] = val + else: + args_new[key] = val + + for key, val in remembered_extra_args.items(): + if key == 'image_channel_name': + init_kwargs_new.setdefault(key, val) + continue + args_new.setdefault(key, val) + + if model_idx == 0: + remembered_extra_args = args_new.copy() + remembered_extra_args['image_channel_name'] = ( + init_kwargs_new.get('image_channel_name', displayed_input_label) + ) + + model_settings.append({ + 'win': win, + 'init_kwargs_new': init_kwargs_new, + 'args_new': args_new, + 'base_model_name': model_name, + 'init_kwargs': dict(win.init_kwargs), + 'model_kwargs': dict(win.model_kwargs), + 'preproc_recipe': win.preproc_recipe, + 'applyPostProcessing': win.applyPostProcessing, + 'standardPostProcessKwargs': win.standardPostProcessKwargs, + 'customPostProcessFeatures': win.customPostProcessFeatures, + 'customPostProcessGroupedFeatures': win.customPostProcessGroupedFeatures, + }) self.SegForLostIDsSettings = { - 'win': win, - 'init_kwargs_new': init_kwargs_new, - 'args_new': args_new, - 'base_model_name': base_model_name, + 'models_settings': model_settings, } + # Persist recipe to disk so it survives across sessions + try: + recipe_data = { + 'models': [ + { + 'base_model_name': ms['base_model_name'], + 'init_kwargs_new': ms['init_kwargs_new'], + 'args_new': ms['args_new'], + 'init_kwargs': ms['init_kwargs'], + 'model_kwargs': ms['model_kwargs'], + 'preproc_recipe': ms['preproc_recipe'], + 'applyPostProcessing': ms['applyPostProcessing'], + 'standardPostProcessKwargs': ms['standardPostProcessKwargs'], + 'customPostProcessFeatures': ms['customPostProcessFeatures'], + 'customPostProcessGroupedFeatures': ms['customPostProcessGroupedFeatures'], + } + for ms in model_settings + ] + } + with open(recipe_json_path, 'w') as f: + json.dump(recipe_data, f, indent=2, default=str) + except Exception as e: + self.logger.warning(f'Could not save recipe to disk: {e}') + def segForLostIDsButtonClicked(self): self.setFrameNavigationDisabled(disable=True, why='Segmentation for lost IDs') @@ -8572,9 +8794,9 @@ def onSegForLostInit(self): self.SegForLostIDsSetSettings() self.SegForLostIDsWaitCond.wakeAll() - def SegForLostIDsWorkerAskInstallModel(self, model_name): - myutils.check_install_package(model_name) - self.SegForLostIDsWaitCond.wakeAll() + # def SegForLostIDsWorkerAskInstallModel(self, model_name): + # myutils.check_install_package(model_name) + # self.SegForLostIDsWaitCond.wakeAll() def startSegForLostIDsWorker(self): self.SegForLostIDsMutex = QMutex() @@ -8588,9 +8810,9 @@ def startSegForLostIDsWorker(self): # Connect the worker's signal to the main thread's slot self.SegForLostIDsWorker.sigAskInit.connect(self.onSegForLostInit) - self.SegForLostIDsWorker.sigAskInstallModel.connect( - self.SegForLostIDsWorkerAskInstallModel - ) + # self.SegForLostIDsWorker.sigAskInstallModel.connect( + # self.SegForLostIDsWorkerAskInstallModel + # ) self.SegForLostIDsWorker.sigshowImageDebug.connect( self.showImageDebug ) @@ -8603,6 +8825,9 @@ def startSegForLostIDsWorker(self): self.onSigStoreDataSegForLostIDsWorker) self.SegForLostIDsWorker.sigUpdateRP.connect( self.onSigUpdateRPSegForLostIDsWorker) + self.SegForLostIDsWorker.sigGetSegForLostIDsInputImg.connect( + self.onSigGetInputImgSegForLostIDsWorker + ) # self.SegForLostIDsWorker.sigGetData.connect(self.onSigGetDataSegForLostIDsWorker) # self.SegForLostIDsWorker.sigGet2Dlab.connect(self.onSigGet2DlabSegForLostIDsWorker) # self.SegForLostIDsWorker.sigGetTrackedLostIDs.connect(self.onSigGetTrackedSegForLostIDsWorker) @@ -8667,6 +8892,31 @@ def onSigTrackManuallyAddedObjectSegForLostIDsWorker(self, added_IDs, isNewID, w self.trackManuallyAddedObject(added_IDs, isNewID, wl_update=wl_update, wl_track_og_curr=wl_track_og_curr) self.SegForLostIDsWaitCond.wakeAll() + def onSigGetInputImgSegForLostIDsWorker(self, image_channel_name): + displayed_input_label = 'Displayed image' + posData = self.data[self.pos_i] + + if ( + not image_channel_name + or image_channel_name == displayed_input_label + ): + img = self.getDisplayedImg1() + self.SegForLostIDsWorker.inputImgForSegForLostIDs = img + self.SegForLostIDsWaitCond.wakeAll() + return + + self.getChData(requ_ch={image_channel_name}) + + _, filename = self.getPathFromChName(image_channel_name, posData) + fluo_data = posData.fluo_data_dict.get(filename) + if posData.SizeT > 1: + fluo_img_data = fluo_data[posData.frame_i] + else: + fluo_img_data = fluo_data + + self.SegForLostIDsWorker.inputImgForSegForLostIDs = fluo_img_data + self.SegForLostIDsWaitCond.wakeAll() + def onSigStoreData( self, waitcond, pos_i=None, enforce=True, debug=False, @@ -20641,7 +20891,10 @@ def framesScrollBarMoved(self, frame_n): def framesScrollBarReleased(self, do_store_data=False): posData = self.data[self.pos_i] - if posData.frame_i == self.navigateScrollBar.sliderPosition()-1: + if ( + posData.frame_i == self.navigateScrollBar.sliderPosition()-1 + and self.navigateScrollBarStartedMoving + ): # Slider released without changing value --> do nothing return @@ -28842,7 +29095,6 @@ def trackFrameCustomTracker( kwargs_total.update(self.track_frame_params) kwargs = {k: v for k, v in kwargs_total.items() if k in self.realTimeTracker_kwargs} - printl(kwargs) tracked_result = self.realTimeTracker.track_frame( prev_lab, currentLab, **kwargs, @@ -28897,7 +29149,6 @@ def trackFrame( return tracked_lab # get assignments - printl(assignments) if assignments is None: assignments = dict() for obj in curr_rp: diff --git a/cellacdc/load.py b/cellacdc/load.py index cea3d737..5c6f634a 100755 --- a/cellacdc/load.py +++ b/cellacdc/load.py @@ -40,7 +40,7 @@ from . import io from . import core from . import IMAGE_EXTENSIONS, VIDEO_EXTENSIONS - +from . import fonts from . import GUI_INSTALLED if GUI_INSTALLED: @@ -3207,7 +3207,7 @@ def askInputMetadata( self.SizeT, self.SizeZ, self.TimeIncrement, self.PhysicalSizeZ, self.PhysicalSizeY, self.PhysicalSizeX, ask_SizeT, ask_TimeIncrement, ask_PhysicalSizes, - parent=self.parent, font=apps.font, imgDataShape=self.img_data_shape, + parent=self.parent, font=fonts.font, imgDataShape=self.img_data_shape, posData=self, singlePos=singlePos, askSegm3D=askSegm3D, additionalValues=self._additionalMetadataValues, forceEnableAskSegm3D=forceEnableAskSegm3D, diff --git a/cellacdc/myutils.py b/cellacdc/myutils.py index 5c638c5b..51ae24e9 100644 --- a/cellacdc/myutils.py +++ b/cellacdc/myutils.py @@ -345,6 +345,7 @@ def __init__( level=logging.DEBUG ): super().__init__(f'{name}-{module}', level=level) + self.propagate = False # prevent UnicodeEncodeError via root StreamHandler self._stdout = sys.stdout self._stderr = StdErr(logger=self) sys.stderr = self._stderr @@ -368,8 +369,13 @@ def write(self, text, log_to_file=True, write_to_stdout=True): log_to_file : bool, optional If True, call `info` method with `text`. Default is True """ - if write_to_stdout: - self._stdout.write(text) + if write_to_stdout: + try: + self._stdout.write(text) + except UnicodeEncodeError: + self._stdout.write(text.encode( + self._stdout.encoding, errors='replace' + ).decode(self._stdout.encoding)) if not log_to_file: return @@ -379,8 +385,11 @@ def write(self, text, log_to_file=True, write_to_stdout=True): if not text: return - - self.debug(text) + + try: + self.debug(text) + except UnicodeEncodeError: + self.debug(text.encode('ascii', errors='replace').decode('ascii')) def close(self): for handler in self.handlers: @@ -615,7 +624,7 @@ def setupLogger(module='base', logs_path=None, caller='Cell-ACDC'): log_filename = f'{date_time}_{module}_{id}_stdout.log' log_path = os.path.join(logs_path, log_filename) - output_file_handler = logging.FileHandler(log_path, mode='w') + output_file_handler = logging.FileHandler(log_path, mode='w', encoding='utf-8') # Format your logs (optional) formatter = logging.Formatter( @@ -1729,7 +1738,7 @@ def download_java(): def get_model_path(model_name, create_temp_dir=True): if model_name == 'Automatic thresholding': - model_name == 'thresholding' + model_name = 'thresholding' model_info_path = os.path.join(cellacdc_path, 'models', model_name, 'model') @@ -4137,13 +4146,31 @@ def init_tracker( return tracker, track_params def import_segment_module(model_name): + original_model_name = model_name + if model_name == 'Automatic thresholding': + model_name = 'thresholding' + try: acdcSegment = import_module(f'cellacdc.models.{model_name}.acdcSegment') except ModuleNotFoundError as e: + # Do not mask missing dependencies imported by the module itself. + expected_missing_module = f'cellacdc.models.{model_name}' + if e.name != expected_missing_module: + raise + # Check if custom model cp = config.ConfigParser() cp.read(models_list_file_path) - model_path = cp[model_name]['path'] + model_key = None + for key in (original_model_name, model_name): + if key in cp: + model_key = key + break + + if model_key is None: + raise + + model_path = cp[model_key]['path'] spec = importlib.util.spec_from_file_location('acdcSegment', model_path) acdcSegment = importlib.util.module_from_spec(spec) sys.modules['acdcSegment'] = acdcSegment diff --git a/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd index 0d2a5fb3..2f80c47d 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd and b/cellacdc/precompiled/precompiled_functions.cp310-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd index 8ae15d50..b581811b 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd and b/cellacdc/precompiled/precompiled_functions.cp311-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd index fc7dcabc..0bc4c832 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd and b/cellacdc/precompiled/precompiled_functions.cp312-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd b/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd index 9a7afa6e..e0178347 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd and b/cellacdc/precompiled/precompiled_functions.cp313-win_amd64.pyd differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so index d184e29e..e5356f1d 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so and b/cellacdc/precompiled/precompiled_functions.cpython-310-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so index 875d866a..e8469613 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so and b/cellacdc/precompiled/precompiled_functions.cpython-310-x86_64-linux-gnu.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so index 3a75cc44..1885f5d3 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so and b/cellacdc/precompiled/precompiled_functions.cpython-311-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so index 74008b7c..96b91d77 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so and b/cellacdc/precompiled/precompiled_functions.cpython-311-x86_64-linux-gnu.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so index 25b7b367..807043b9 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so and b/cellacdc/precompiled/precompiled_functions.cpython-312-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so index aed04d4b..14c6dd99 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so and b/cellacdc/precompiled/precompiled_functions.cpython-312-x86_64-linux-gnu.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so b/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so index 95ca4977..14da0753 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so and b/cellacdc/precompiled/precompiled_functions.cpython-313-darwin.so differ diff --git a/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so b/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so index 36620d35..f19d4d7f 100644 Binary files a/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so and b/cellacdc/precompiled/precompiled_functions.cpython-313-x86_64-linux-gnu.so differ diff --git a/cellacdc/segm.py b/cellacdc/segm.py index c75c2589..293d13ec 100755 --- a/cellacdc/segm.py +++ b/cellacdc/segm.py @@ -518,7 +518,7 @@ def main(self): model_name = win.selectedModel - if model_name == 'thresholding': + if model_name in ('thresholding', 'Automatic thresholding'): win = apps.QDialogAutomaticThresholding( parent=self, isSegm3D=self.isSegm3D ) @@ -528,6 +528,9 @@ def main(self): return self.model_kwargs = win.segment_kwargs + if model_name == 'Automatic thresholding': + model_name = 'thresholding' + self.log(f'Downloading {model_name} (if needed)...') self.downloadWin = apps.downloadModel(model_name, parent=self) self.downloadWin.download() diff --git a/cellacdc/segm_utils.py b/cellacdc/segm_utils.py index 790e5971..5a44a202 100644 --- a/cellacdc/segm_utils.py +++ b/cellacdc/segm_utils.py @@ -38,6 +38,22 @@ def find_overlap(lab_1, lab_2): return ID_overlap +def get_best_overlapping_label(label_img, obj, allowed_labels): + allowed_labels = set(allowed_labels) + if len(allowed_labels) == 0: + return None + + overlapping_labels = label_img[obj.slice][obj.image] + if overlapping_labels.size == 0: + return None + + overlapping_labels = overlapping_labels[np.isin(overlapping_labels, tuple(allowed_labels))] + if overlapping_labels.size == 0: + return None + + labels, counts = np.unique(overlapping_labels, return_counts=True) + return labels[np.argmax(counts)] + def get_obj_from_rps(rps, ID): for obj in rps: if obj.label == ID: @@ -164,9 +180,15 @@ def boxes_overlap(bbox1, bbox2): # return np.unique(border_labels[border_labels != 0]) def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, - win, posData, distance_filler_growth=1, + posData, distance_filler_growth=1, overlap_threshold=0.5, padding=0.4, export_bbox_for_training=False, + model_kwargs=None, + preproc_recipe=None, + applyPostProcessing=False, + standardPostProcessKwargs=None, + customPostProcessFeatures=None, + customPostProcessGroupedFeatures=None, ): """ Function to segment single cells in the current frame using the previous frame segmentation as a reference. @@ -178,12 +200,17 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, curr_img: current frame image IDs: list of IDs of the cells to segment new_unique_ID: ID to start labeling new cells - win: from the gui window which sets model params posData: position data (see rest of acdc) distance_filler_growth: distance to grow the other IDs to fill the background overlap_threshold: minimum overlap percentage to consider a cell already segmented padding: padding around the cell to segment export_bbox_for_training: if True, export bounding boxes for training model + model_kwargs: keyword arguments to pass to the segmentation model + preproc_recipe: preprocessing recipe to apply before segmentation + applyPostProcessing: if True, apply post-processing to the segmentation + standardPostProcessKwargs: keyword arguments for standard post-processing + customPostProcessFeatures: custom features for post-processing segmentation + customPostProcessGroupedFeatures: custom grouped features for post-processing Returns: curr_lab: current frame segmentation with the segmented cells @@ -194,12 +221,10 @@ def single_cell_seg(model, prev_lab, curr_lab, curr_img, IDs, new_unique_ID, if export_bbox_for_training: bboxs_for_debug = [] - model_kwargs = win.model_kwargs - preproc_recipe = win.preproc_recipe - applyPostProcessing = win.applyPostProcessing - standardPostProcessKwargs = win.standardPostProcessKwargs - customPostProcessFeatures = win.customPostProcessFeatures - customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures + if model_kwargs is None: + model_kwargs = {} + if standardPostProcessKwargs is None: + standardPostProcessKwargs = {} prev_rp = skimage.measure.regionprops(prev_lab) prev_lab_shape = prev_lab.shape diff --git a/cellacdc/widgets.py b/cellacdc/widgets.py index b22e1cec..bb247e50 100755 --- a/cellacdc/widgets.py +++ b/cellacdc/widgets.py @@ -71,6 +71,7 @@ from . import _core, core from . import QtScoped from . import prompts +from . import fonts from .acdc_regex import float_regex from .config import PREPROCESS_MAPPER from . import _base_widgets @@ -84,8 +85,7 @@ PROGRESSBAR_HIGHLIGHTEDTEXT_QCOLOR = _palettes.QProgressBarHighlightedTextColor() TEXT_COLOR = _palettes.text_float_rgba() -font = QFont() -font.setPixelSize(12) +font = fonts.font custom_cmaps_filepath = os.path.join(settings_folderpath, 'custom_colormaps.ini') @@ -1075,7 +1075,7 @@ class _ReorderableListModel(QAbstractListModel): def __init__(self, items, parent=None): QAbstractItemModel.__init__(self, parent) - self.nodes = items + self.nodes = list(items) self.lastDroppedItems = [] self.pendingRemoveRowsAfterDrop = False @@ -1286,7 +1286,7 @@ def __init__( self.setStyleSheet(styleSheet) def setItems(self, items): - self._model.nodes = items + self._model.nodes = list(items) def items(self): return self._model.nodes @@ -1464,10 +1464,9 @@ def warnSelectionEmpty(self): def ok_cb(self, checked=False): self.clickedButton = self.sender() - self.cancel = False selectedItems = self.listBox.selectedItems() - self.selectedItemsText = [item.text() for item in selectedItems] - if not self.allowSingleSelection and len(self.selectedItemsText) < 2: + selectedItemsText = [item.text() for item in selectedItems] + if not self.allowSingleSelection and len(selectedItemsText) < 2: msg = myMessageBox(wrapText=False, showCentered=False) txt = html_utils.paragraph( 'You need to select two or more items.

' @@ -1477,9 +1476,12 @@ def ok_cb(self, checked=False): msg.warning(self, 'Select two or more items', txt) return - if not self.allowEmptySelection and not self.selectedItemsText: + if not self.allowEmptySelection and not selectedItemsText: self.warnSelectionEmpty() return + + self.cancel = False + self.selectedItemsText = selectedItemsText self.sigSelectionConfirmed.emit(self.selectedItemsText) self.close() @@ -12083,4 +12085,305 @@ def closeEvent(self, event): if self.screenShotWin is not None: self.screenShotWin.close() - return super().closeEvent(event) \ No newline at end of file + return super().closeEvent(event) + + +class MultiPickListWidget(QWidget): + """Generic list widget with multi-pick (repeated-selection) support. + + Each pickable row shows ``- count +`` controls. Left-clicking adds + one instance; right-clicking or Ctrl+left-click removes one. The same + item can appear multiple times in :attr:`selectionSequence`. + + Parameters + ---------- + items: + Initial list of item labels. + excludedItems: + Labels that are shown in the list but *not* given +/- controls + (e.g. placeholder entries like "Add custom model…"). Click events + on these are silently ignored. + parent: + Optional parent widget. + """ + + sigSelectionChanged = Signal(list) # emits selectionSequence on every change + + def __init__(self, items=None, excludedItems=None, parent=None): + super().__init__(parent) + + self._excludedItems = set(excludedItems or []) + self._itemsMap = {} # label → QListWidgetItem + self._countMap = defaultdict(int) + self._countLabelMap = {} + self.selectionSequence = [] + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + self.listBox = listWidget(isMultipleSelection=False) + self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + + for label in (items or []): + self._addListItem(label) + + if self._itemsMap: + self.listBox.setCurrentRow(0) + + self.listBox.itemClicked.connect(self._onItemClicked) + self.listBox.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.listBox.customContextMenuRequested.connect(self._onRightClick) + + # self.listBox.setStyleSheet(LISTWIDGET_STYLESHEET) + # self.setStyleSheet(LISTWIDGET_STYLESHEET) + layout.addWidget(self.listBox) + + + @property + def itemsMap(self): + """Dict mapping label → QListWidgetItem for all pickable items.""" + return dict(self._itemsMap) + + def currentItemName(self): + """Return the label of the currently highlighted item, or ``None``.""" + item = self.listBox.currentItem() + return item.text() if item is not None else None + + def addSelection(self, label): + """Add one instance of *label* to the selection.""" + if label not in self._itemsMap: + return + self.selectionSequence.append(label) + self._countMap[label] += 1 + self._updateCountLabel(label) + self.listBox.setCurrentItem(self._itemsMap[label]) + self.sigSelectionChanged.emit(list(self.selectionSequence)) + + def removeSelection(self, label): + """Remove the last instance of *label* from the selection.""" + if self._countMap.get(label, 0) <= 0: + return + for i in range(len(self.selectionSequence) - 1, -1, -1): + if self.selectionSequence[i] == label: + self.selectionSequence.pop(i) + break + self._countMap[label] = max(0, self._countMap[label] - 1) + self._updateCountLabel(label) + self.sigSelectionChanged.emit(list(self.selectionSequence)) + + def resetSelection(self): + """Clear all selections and reset all counters to zero.""" + self.selectionSequence = [] + self._countMap = defaultdict(int) + for label in self._countLabelMap: + self._updateCountLabel(label) + self.sigSelectionChanged.emit([]) + + def setSelectionFromList(self, labels): + """Set the selection to *labels* (duplicates supported).""" + self.resetSelection() + for label in labels: + self.addSelection(label) + + def registerItem(self, label, insertBeforeLabel=None): + """Dynamically add a new pickable item. + + Parameters + ---------- + label: + Text for the new item. + insertBeforeLabel: + If given, insert the new item immediately before this label. + If not found or not given, the item is appended. + + Returns the created ``QListWidgetItem``. + """ + if label in self._itemsMap: + return self._itemsMap[label] + + item = QListWidgetItem(label) + + if insertBeforeLabel is not None: + target = self._itemsMap.get(insertBeforeLabel) + if target is None: + for row in range(self.listBox.count()): + row_item = self.listBox.item(row) + if row_item is not None and row_item.text() == insertBeforeLabel: + target = row_item + break + if target is not None: + row = self.listBox.row(target) + self.listBox.insertItem(row, item) + else: + self.listBox.addItem(item) + else: + self.listBox.addItem(item) + + self._itemsMap[label] = item + self._addCounterWidget(label, item) + return item + + def _addListItem(self, label): + """Create a QListWidgetItem and, if pickable, attach a counter widget.""" + item = QListWidgetItem(label) + self.listBox.addItem(item) + if label not in self._excludedItems: + self._itemsMap[label] = item + self._addCounterWidget(label, item) + + def _addCounterWidget(self, label, item): + rowWidget = QWidget() + rowLayout = QHBoxLayout(rowWidget) + rowLayout.setContentsMargins(4, 0, 4, 0) + rowLayout.setSpacing(6) + + nameLabelPlaceholder = QSpacerItem(2, 0) + minusBtn = QPushButton('-') + plusBtn = QPushButton('+') + countLabel = QLabel(str(self._countMap.get(label, 0))) + + minusBtn.setFixedWidth(24) + plusBtn.setFixedWidth(24) + countLabel.setMinimumWidth(20) + countLabel.setAlignment(Qt.AlignCenter) + + minusBtn.clicked.connect(lambda _, lbl=label: self.removeSelection(lbl)) + plusBtn.clicked.connect(lambda _, lbl=label: self.addSelection(lbl)) + + rowLayout.addItem(nameLabelPlaceholder) + rowLayout.addStretch(1) + rowLayout.addWidget(minusBtn) + rowLayout.addWidget(countLabel) + rowLayout.addWidget(plusBtn) + + self._countLabelMap[label] = countLabel + self._updateCountLabel(label) + self.listBox.setItemWidget(item, rowWidget) + + def _updateCountLabel(self, label): + lbl = self._countLabelMap.get(label) + if lbl is not None: + count = self._countMap.get(label, 0) + lbl.setText(str(count)) + if count <= 0: + lbl.setStyleSheet('color: gray;') + else: + lbl.setStyleSheet('') + + def _onItemClicked(self, item): + label = item.text() + if label in self._excludedItems: + return + modifiers = QApplication.keyboardModifiers() + if modifiers & Qt.ControlModifier: + self.removeSelection(label) + else: + self.addSelection(label) + + def _onRightClick(self, pos): + item = self.listBox.itemAt(pos) + if item is None: + return + label = item.text() + if label in self._excludedItems: + return + self.removeSelection(label) + + +class ModelSelectionWidget(QWidget): + """List widget for selecting segmentation models. + + Thin wrapper around :class:`MultiPickListWidget` that populates the list + with the installed models and adds a special "Add custom model…" entry. + + ``sigSelectionChanged`` and ``selectionSequence`` are proxied from the + underlying :class:`MultiPickListWidget`. + """ + + _ADD_CUSTOM = 'Add custom model...' + + sigSelectionChanged = Signal(list) + + def __init__(self, parent=None, customFirst='', allowMultiSelection=False): + super().__init__(parent) + + self.allowMultiSelection = allowMultiSelection + + models = myutils.get_list_of_models() + if customFirst: + try: + models.insert(0, models.pop(models.index(customFirst))) + except ValueError: + pass + + items = models + [self._ADD_CUSTOM] + + layout = QVBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + + if allowMultiSelection: + items = models + self._picker = MultiPickListWidget( + items=items, + excludedItems=[self._ADD_CUSTOM], + parent=self, + ) + self._picker.listBox.setFont(font) + self._picker.sigSelectionChanged.connect(self.sigSelectionChanged) + self.listBox = self._picker.listBox + layout.addWidget(self._picker) + else: + self.listBox = listWidget(isMultipleSelection=False) + self.listBox.setFont(font) + self.listBox.addItems(models) + add_item = QListWidgetItem(self._ADD_CUSTOM) + add_item.setFont(fonts.italicFont) + self.listBox.addItem(add_item) + self.listBox.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection) + self.listBox.setCurrentRow(0) + self._picker = None + layout.addWidget(self.listBox) + + # ------------------------------------------------------------------ + # Proxy helpers (multi-selection mode only) + # ------------------------------------------------------------------ + + @property + def selectionSequence(self): + return self._picker.selectionSequence if self._picker is not None else [] + + @property + def modelItemsMap(self): + return self._picker.itemsMap if self._picker is not None else {} + + def currentModelName(self): + if self._picker is not None: + return self._picker.currentItemName() + item = self.listBox.currentItem() + return item.text() if item is not None else None + + def addModelSelection(self, name): + if self._picker is not None: + self._picker.addSelection(name) + + def removeModelSelection(self, name): + if self._picker is not None: + self._picker.removeSelection(name) + + def resetSelectionSequence(self): + if self._picker is not None: + self._picker.resetSelection() + + def setSelectionFromList(self, models): + if self._picker is not None: + self._picker.setSelectionFromList(models) + + def registerCustomModel(self, model_name): + """Add a newly registered custom model and return its item.""" + if self._picker is not None: + return self._picker.registerItem( + model_name, insertBeforeLabel=self._ADD_CUSTOM + ) + item = QListWidgetItem(model_name) + self.listBox.insertItem(self.listBox.count() - 1, item) + return item diff --git a/cellacdc/workers.py b/cellacdc/workers.py index 6127d623..29f8e7eb 100755 --- a/cellacdc/workers.py +++ b/cellacdc/workers.py @@ -43,7 +43,7 @@ from . import cli from .utils import resize from . import segm_utils - +from . import regionprops DEBUG = False def worker_exception_handler(func): @@ -198,10 +198,11 @@ def run(self): class SegForLostIDsWorker(QObject): sigAskInit = Signal() - sigAskInstallModel = Signal(str) + # sigAskInstallModel = Signal(str) sigshowImageDebug = Signal(object) sigStoreData = Signal(bool) sigUpdateRP = Signal(bool, bool) + sigGetSegForLostIDsInputImg = Signal(str) # sigGetData = Signal() # sigGet2Dlab = Signal() # sigGetTrackedLostIDs = Signal() @@ -217,6 +218,7 @@ def __init__(self, guiWin, mutex, waitCond, debug=False): self.mutex = mutex self.waitCond = waitCond self._debug = debug + self.inputImgForSegForLostIDs = None def emitSigAskInit(self): self.mutex.lock() @@ -242,17 +244,26 @@ def emitSigUpdateRP(self, wl_track_og_curr, wl_update): self.waitCond.wait(self.mutex) self.mutex.unlock() + def emitGetSegForLostIDsInputImg(self, image_channel_name): + self.mutex.lock() + self.sigGetSegForLostIDsInputImg.emit(image_channel_name) + self.waitCond.wait(self.mutex) + img = self.inputImgForSegForLostIDs + self.inputImgForSegForLostIDs = None + self.mutex.unlock() + return img + # def emitSigGetData(self): # self.mutex.lock() # self.sigGetData.emit() # self.waitCond.wait(self.mutex) # self.mutex.unlock() - def emitSigAskInstallModel(self, model_name): - self.mutex.lock() - self.sigAskInstallModel.emit(model_name) - self.waitCond.wait(self.mutex) - self.mutex.unlock() + # def emitSigAskInstallModel(self, model_name): + # self.mutex.lock() + # self.sigAskInstallModel.emit(model_name) + # self.waitCond.wait(self.mutex) + # self.mutex.unlock() def emitSigAskInstallGPU(self, base_model_name, use_gpu): self.mutex.lock() @@ -298,36 +309,58 @@ def run(self): return self.logger.info('Segmentation for lost IDs started.') - model_name = 'local_seg' - base_model_name = self.guiWin.SegForLostIDsSettings['base_model_name'] - idx = self.guiWin.modelNames.index(model_name) - acdcSegment = self.guiWin.acdcSegment_li[idx] - - init_kwargs = self.guiWin.SegForLostIDsSettings['win'].init_kwargs - - use_gpu = init_kwargs.get('device_type', 'cpu') != 'cpu' - use_gpu = use_gpu or init_kwargs.get('use_gpu', False) - - self.emitSigAskInstallGPU(base_model_name, use_gpu) - - if not self.gpu_go: - self.signals.finished.emit(self) - return - - if not self.dont_force_cpu: - if 'device' in init_kwargs: - init_kwargs['device'] = 'cpu' - if 'use_gpu' in init_kwargs: - init_kwargs['use_gpu'] = False - if acdcSegment is None or base_model_name != self.guiWin.local_seg_base_model_name: + model_settings = self.guiWin.SegForLostIDsSettings['models_settings'] + + n_models = len(model_settings) + total_steps = 2 * n_models + self.signals.initProgressBar.emit(total_steps) + + for model_idx, model_settings_i in enumerate(model_settings): + base_model_name = model_settings_i['base_model_name'] + init_kwargs_new = dict(model_settings_i['init_kwargs_new']) + image_channel_name = init_kwargs_new.pop( + 'image_channel_name', 'Displayed image' + ) + args_new = model_settings_i['args_new'] + init_kwargs = model_settings_i.get('init_kwargs', {}) + model_kwargs = model_settings_i.get('model_kwargs', {}) + preproc_recipe = model_settings_i.get('preproc_recipe', None) + applyPostProcessing = model_settings_i.get('applyPostProcessing', False) + standardPostProcessKwargs = model_settings_i.get('standardPostProcessKwargs', {}) + customPostProcessFeatures = model_settings_i.get('customPostProcessFeatures', None) + customPostProcessGroupedFeatures = model_settings_i.get('customPostProcessGroupedFeatures', None) + + # Fall back to reading from the live win object when available + win = model_settings_i.get('win') + if win is not None: + init_kwargs = win.init_kwargs + model_kwargs = win.model_kwargs + preproc_recipe = win.preproc_recipe + applyPostProcessing = win.applyPostProcessing + standardPostProcessKwargs = win.standardPostProcessKwargs + customPostProcessFeatures = win.customPostProcessFeatures + customPostProcessGroupedFeatures = win.customPostProcessGroupedFeatures + + use_gpu = init_kwargs.get('device_type', 'cpu') != 'cpu' + use_gpu = use_gpu or init_kwargs.get('use_gpu', False) + + self.emitSigAskInstallGPU(base_model_name, use_gpu) + + if not self.gpu_go: + self.signals.finished.emit(self) + return + + if not self.dont_force_cpu: + if 'device' in init_kwargs: + init_kwargs_new = dict(init_kwargs_new, device='cpu') + if 'use_gpu' in init_kwargs: + init_kwargs_new = dict(init_kwargs_new, use_gpu=False) + try: self.logger.info(f'Importing {base_model_name}...') - self.emitSigAskInstallModel(base_model_name) acdcSegment = myutils.import_segment_module(base_model_name) - self.guiWin.acdcSegment_li[idx] = acdcSegment - self.guiWin.local_seg_base_model_name = base_model_name - except (IndexError, ImportError, KeyError) as e: + except (IndexError, ImportError, KeyError): self.logger.warning( f'Cannot import {base_model_name} model. ' 'Please install it first.' @@ -339,122 +372,157 @@ def run(self): self.signals.finished.emit(self) return - win = self.guiWin.SegForLostIDsSettings['win'] - init_kwargs_new = self.guiWin.SegForLostIDsSettings['init_kwargs_new'] - args_new = self.guiWin.SegForLostIDsSettings['args_new'] - - model = myutils.init_segm_model(acdcSegment, posData, init_kwargs_new) - if model is None: - self.logger.info('Segmentation model was not initialized correctly!') - self.signals.critical.emit( - (self, 'Segmentation model was not initialized correctly!') - ) - self.signals.finished.emit(self) - return - if self._debug: - try: - model.setupLogger(self.guiwin.logger) - except Exception as e: - pass + model = myutils.init_segm_model(acdcSegment, posData, init_kwargs_new) + if model is None: + self.logger.info('Segmentation model was not initialized correctly!') + self.signals.critical.emit( + (self, 'Segmentation model was not initialized correctly!') + ) + self.signals.finished.emit(self) + return + if self._debug: + try: + model.setupLogger(self.guiWin.logger) + except Exception: + pass - assigned_IDs = [] - missing_IDs_global = set() - original_lab = posData.lab.copy() - IDs_bboxs_list = [] - bboxs_list = [] + assigned_IDs = [] + missing_IDs_global = set() + original_lab = posData.lab.copy() + IDs_bboxs_list = [] + bboxs_list = [] - curr_img = self.guiWin.getDisplayedImg1() - prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels']) - prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set + curr_img = self.emitGetSegForLostIDsInputImg(image_channel_name) + if curr_img is None: + self.signals.critical.emit( + (self, 'Could not get input image for SegForLostIDsWorker') + ) + self.signals.finished.emit(self) + return + prev_lab = self.guiWin.get_2Dlab(posData.allData_li[frame_i-1]['labels']) + prev_IDs = posData.allData_li[frame_i-1]['regionprops'].IDs_set - # should probably not paly so much with posData.lab, instead handle stuff myself - self.signals.initProgressBar.emit(2 * args_new['max_iterations']) - new_labs = np.zeros([args_new['max_iterations'], *posData.lab.shape], dtype=np.uint32) - for i in range(args_new['max_iterations']): + new_labs = [] curr_lab = self.guiWin.get_2Dlab(posData.lab) tracked_lost_IDs = self.guiWin.getTrackedLostIDs() new_unique_ID = self.guiWin.setBrushID(useCurrentLab=True, return_val=True) - missing_IDs = prev_IDs - set(posData.IDs) - set(tracked_lost_IDs) + missing_IDs = prev_IDs - posData.rp.IDs_set - set(tracked_lost_IDs) missing_IDs_global.update(missing_IDs) assigned_IDs_prev = assigned_IDs.copy() out = segm_utils.single_cell_seg( - model, prev_lab, curr_lab, curr_img, + model, prev_lab, curr_lab, curr_img, missing_IDs, new_unique_ID, - win, posData, + posData, distance_filler_growth=args_new['distance_filler_growth'], overlap_threshold=args_new['overlap_threshold'], padding=args_new['padding'], + model_kwargs=model_kwargs, + preproc_recipe=preproc_recipe, + applyPostProcessing=applyPostProcessing, + standardPostProcessKwargs=standardPostProcessKwargs, + customPostProcessFeatures=customPostProcessFeatures, + customPostProcessGroupedFeatures=customPostProcessGroupedFeatures, ) new_lab, assigned_IDs, IDs_bboxs, bboxs = out - + IDs_bboxs_list.append(IDs_bboxs) bboxs_list.append(bboxs) posData.lab = new_lab self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False) newly_assigned_IDs = set(assigned_IDs) - set(assigned_IDs_prev) self.emitTrackManuallyAddedObject(newly_assigned_IDs, True, False, False) - new_labs[i] = posData.lab.copy() + new_labs.append(posData.lab.copy()) self.signals.progressBar.emit(1) - - if self._debug: - originals = [] - models = [] - posData.lab = original_lab.copy() - - global_area_mean = np.mean([obj.area for obj in posData.rp]) - for IDs_bboxs, bboxs in zip(IDs_bboxs_list, bboxs_list): - model_lab = new_labs[i] if self._debug: - originals.append(original_lab.copy()) - models.append(posData.lab.copy()) + originals = [] + models = [] + + posData.lab = original_lab.copy() + self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False) + + global_areas = [obj.area for obj in posData.rp] + global_area_mean = np.mean(global_areas) if len(global_areas) > 0 else None + for i, (IDs_bboxs, bboxs) in enumerate(zip(IDs_bboxs_list, bboxs_list)): + model_lab = new_labs[i] + if self._debug: + originals.append(original_lab.copy()) + models.append(posData.lab.copy()) - for IDs, bbox in zip(IDs_bboxs, bboxs): + for IDs, bbox in zip(IDs_bboxs, bboxs): - box_x_min, box_x_max, box_y_min, box_y_max = bbox - original_bbox_lab = original_lab[box_x_min:box_x_max, box_y_min:box_y_max] - original_bbox_lab_cleared_borders = skimage.segmentation.clear_border(original_bbox_lab) - box_model_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max] + box_x_min, box_x_max, box_y_min, box_y_max = bbox + original_bbox_lab = original_lab[box_x_min:box_x_max, box_y_min:box_y_max] + original_bbox_lab_cleared_borders = skimage.segmentation.clear_border(original_bbox_lab) + box_model_lab = model_lab[box_x_min:box_x_max, box_y_min:box_y_max] # original_bbox_lab[np.isin(original_bbox_lab, IDs)] = 0 should be a given. If not seg for lost IDs this recommended - box_model_lab = skimage.segmentation.clear_border(box_model_lab, buffer_size=1) + box_model_lab = skimage.segmentation.clear_border(box_model_lab, buffer_size=1) - rp_model_lab = skimage.measure.regionprops(box_model_lab) - rp_original_lab = skimage.measure.regionprops(original_bbox_lab) - rp_original_lab_cleared = skimage.measure.regionprops(original_bbox_lab_cleared_borders) + rp_model_lab = regionprops.acdcRegionprops(box_model_lab, precache_centroids=False) + rp_original_lab = regionprops.acdcRegionprops(original_bbox_lab, precache_centroids=False) + rp_original_lab_cleared = regionprops.acdcRegionprops(original_bbox_lab_cleared_borders, precache_centroids=False) - original_IDs = [obj.label for obj in rp_original_lab] - areas = [obj.area for obj in rp_original_lab_cleared] - if len(areas) > 0: - area_mean = np.mean(areas) - else: - area_mean = global_area_mean - if args_new['allow_only_tracked_cells']: - filtered_IDs = [obj.label for obj in rp_model_lab - if obj.area > (1 - args_new['size_perc_diff']) * area_mean - and obj.area < (1 + args_new['size_perc_diff']) * area_mean + original_IDs = [obj.label for obj in rp_original_lab] + areas = [obj.area for obj in rp_original_lab_cleared] + if len(areas) > 0: + area_mean = np.mean(areas) + elif global_area_mean is not None: + area_mean = global_area_mean + else: + model_areas = [obj.area for obj in rp_model_lab] + area_mean = np.mean(model_areas) if len(model_areas) > 0 else None + + skip_size_filter = area_mean is None + if not skip_size_filter: + min_area = (1 - args_new['size_perc_diff']) * area_mean + max_area = (1 + args_new['size_perc_diff']) * area_mean + + prev_bbox_lab = prev_lab[box_x_min:box_x_max, box_y_min:box_y_max] + relabeled_IDs = {} + if args_new['allow_only_tracked_cells']: + filtered_IDs = [] + for obj in rp_model_lab: + if not (skip_size_filter or (obj.area > min_area and obj.area < max_area)): + continue + if obj.label in original_IDs: + continue + + target_ID = segm_utils.get_best_overlapping_label( + prev_bbox_lab, + obj, + missing_IDs_global, + ) + if target_ID is None: + continue + + filtered_IDs.append(obj.label) + relabeled_IDs[obj.label] = target_ID + else: + filtered_IDs = [ + obj.label for obj in rp_model_lab + if (skip_size_filter or (obj.area > min_area and obj.area < max_area)) and obj.label not in original_IDs - and obj.label in missing_IDs_global] - else: - filtered_IDs = [obj.label for obj in rp_model_lab - if obj.area > (1 - args_new['size_perc_diff']) * area_mean - and obj.area < (1 + args_new['size_perc_diff']) * area_mean - and obj.label not in original_IDs] - - if self._debug or DEBUG: - filtered_sizes = [(obj.label, obj.area) for obj in rp_model_lab if obj.label in filtered_IDs] - self.logger.info(f"Filtered sizes: {filtered_sizes}") - for label in filtered_IDs: - original_bbox_lab[box_model_lab == label] = label # here the stuff should be tracked, so we keep the ID! + ] + + if self._debug or DEBUG: + filtered_sizes = [(obj.label, obj.area) for obj in rp_model_lab if obj.label in filtered_IDs] + self.logger.info(f"Filtered sizes: {filtered_sizes}") + for label in filtered_IDs: + obj = rp_model_lab.get_obj_from_ID(label) + target_label = relabeled_IDs.get(label, label) + + original_bbox_lab[obj.slice][obj.image] = target_label # original_lab[box_x_min:box_x_max, box_y_min:box_y_max] = original_bbox_lab - - self.signals.progressBar.emit(1) - - posData.lab = original_lab + + self.signals.progressBar.emit(1) + + posData.lab = original_lab + self.emitSigUpdateRP(wl_update=True, wl_track_og_curr=False) # if self._debug: # originals = np.concatenate(originals, axis=0) diff --git a/tests/test_segm_utils.py b/tests/test_segm_utils.py new file mode 100644 index 00000000..4617b059 --- /dev/null +++ b/tests/test_segm_utils.py @@ -0,0 +1,174 @@ +import types + +import numpy as np +from skimage.measure import regionprops +from types import SimpleNamespace + +from cellacdc import myutils +from cellacdc.regionprops import acdcRegionprops +from cellacdc.workers import SegForLostIDsWorker + +from cellacdc.models.thresholding.acdcSegment import Model as ThresholdingModel +from cellacdc.segm_utils import get_best_overlapping_label + + +def test_get_best_overlapping_label_uses_majority_overlap_with_allowed_labels(): + label_img = np.zeros((8, 8), dtype=np.uint16) + label_img[2:6, 2:4] = 4 + label_img[3:7, 4:6] = 7 + + obj = types.SimpleNamespace( + slice=(slice(2, 7), slice(2, 6)), + image=np.array( + [ + [False, False, False, False], + [True, True, False, False], + [True, True, True, True], + [True, True, True, True], + [False, False, True, True], + ] + ), + ) + + assert get_best_overlapping_label(label_img, obj, {4, 7}) == 4 + + +def test_get_best_overlapping_label_returns_none_without_allowed_overlap(): + label_img = np.zeros((6, 6), dtype=np.uint16) + label_img[1:3, 1:3] = 2 + + obj = types.SimpleNamespace( + slice=(slice(1, 4), slice(1, 4)), + image=np.array( + [ + [False, True, False], + [True, True, True], + [False, True, False], + ] + ), + ) + + assert get_best_overlapping_label(label_img, obj, {5}) is None + + +def test_thresholding_model_object_can_be_mapped_back_to_missing_id(): + prev_lab = np.zeros((10, 10), dtype=np.uint16) + prev_lab[3:7, 3:7] = 5 + + image = np.zeros((10, 10), dtype=np.float32) + image[3:7, 3:7] = 10.0 + + model = ThresholdingModel() + model_lab = model.segment( + image, + gauss_sigma=0, + threshold_method='threshold_otsu', + ) + + rp_model = regionprops(model_lab) + assert len(rp_model) == 1 + + recovered_id = get_best_overlapping_label(prev_lab, rp_model[0], {5}) + + assert recovered_id == 5 + + +class _DummyLogger: + def info(self, message): + pass + + def warning(self, message): + pass + + def error(self, message): + pass + + +class _DummySignal: + def emit(self, *args, **kwargs): + pass + + +class _DummySignals: + def __init__(self): + self.progress = _DummySignal() + self.finished = _DummySignal() + self.initProgressBar = _DummySignal() + self.progressBar = _DummySignal() + self.critical = _DummySignal() + + +def test_seg_for_lost_ids_worker_thresholding_relabels_recovered_object(monkeypatch): + prev_lab = np.zeros((10, 10), dtype=np.uint16) + prev_lab[3:7, 3:7] = 5 + + curr_lab = np.zeros((10, 10), dtype=np.uint16) + curr_img = np.zeros((10, 10), dtype=np.float32) + curr_img[3:7, 3:7] = 10.0 + + prev_rp = acdcRegionprops(prev_lab) + curr_rp = acdcRegionprops(curr_lab) + + posData = SimpleNamespace( + frame_i=1, + lab=curr_lab.copy(), + rp=curr_rp, + allData_li=[ + {'labels': prev_lab, 'regionprops': prev_rp}, + {'labels': curr_lab, 'regionprops': curr_rp}, + ], + ) + + guiWin = SimpleNamespace( + data=[posData], + pos_i=0, + SegForLostIDsSettings={ + 'models_settings': [ + { + 'base_model_name': 'thresholding', + 'init_kwargs_new': {}, + 'args_new': { + 'distance_filler_growth': 1.0, + 'overlap_threshold': 0.5, + 'padding': 1.0, + 'size_perc_diff': 1.0, + 'allow_only_tracked_cells': True, + }, + 'init_kwargs': {}, + 'model_kwargs': { + 'gauss_sigma': 0, + 'threshold_method': 'threshold_otsu', + }, + 'preproc_recipe': None, + 'applyPostProcessing': False, + 'standardPostProcessKwargs': {}, + 'customPostProcessFeatures': None, + 'customPostProcessGroupedFeatures': None, + } + ] + }, + getDisplayedImg1=lambda: curr_img, + get_2Dlab=lambda lab: lab, + getTrackedLostIDs=lambda: [], + setBrushID=lambda useCurrentLab=True, return_val=True: 10, + logger=_DummyLogger(), + ) + + worker = SegForLostIDsWorker(guiWin, mutex=SimpleNamespace(lock=lambda: None, unlock=lambda: None), waitCond=SimpleNamespace(wait=lambda mutex: None)) + worker.signals = _DummySignals() + worker.logger = _DummyLogger() + worker.gpu_go = True + worker.dont_force_cpu = True + + monkeypatch.setattr(worker, 'emitSigAskInit', lambda: None) + monkeypatch.setattr(worker, 'emitSigAskInstallGPU', lambda base_model_name, use_gpu: None) + monkeypatch.setattr(worker, 'emitSigUpdateRP', lambda wl_update=True, wl_track_og_curr=False: None) + monkeypatch.setattr(worker, 'emitSigStoreData', lambda autosave=True: None) + monkeypatch.setattr(worker, 'emitTrackManuallyAddedObject', lambda *args, **kwargs: None) + monkeypatch.setattr(myutils, 'import_segment_module', lambda base_model_name: SimpleNamespace(Model=ThresholdingModel)) + monkeypatch.setattr(myutils, 'init_segm_model', lambda acdcSegment, posData, init_kwargs_new: ThresholdingModel()) + + worker.run() + + assert posData.lab[3:7, 3:7].min() == 5 + assert posData.lab[3:7, 3:7].max() == 5 \ No newline at end of file