diff --git a/luxonis_train/config/predefined_models/base_predefined_model.py b/luxonis_train/config/predefined_models/base_predefined_model.py index f8a820e7..3adf676d 100644 --- a/luxonis_train/config/predefined_models/base_predefined_model.py +++ b/luxonis_train/config/predefined_models/base_predefined_model.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from contextlib import suppress from typing import Literal from loguru import logger @@ -208,42 +209,47 @@ def nodes(self) -> list[NodeConfig]: return nodes def _generate_metrics(self) -> list[MetricModuleConfig]: - if self._per_class_metrics is None: - return [ - MetricModuleConfig( - name=metric, - params=self._metrics_params, - is_main_metric=metric == self._main_metric, - ) - for metric in self._metrics - ] - - task = NODES.get(self._head).task - metrics = [] - applied_per_class_override = False - - for metric in self._metrics: - metric_params = dict(self._metrics_params) - metric_cls = METRICS.get(metric) - aliases = metric_cls.get_predefined_model_params_aliases(task) - param_name = aliases.get("per_class_metrics") - if param_name is not None: - metric_params[param_name] = self._per_class_metrics - applied_per_class_override = True - - metrics.append( - MetricModuleConfig( - name=metric, - params=metric_params, - is_main_metric=metric == self._main_metric, - ) - ) + if self._per_class_metrics is not None: + try: + task = NODES.get(self._head).task + metrics = [] + applied_per_class_override = False + + for metric in self._metrics: + metric_params = dict(self._metrics_params) + metric_cls = METRICS.get(metric) + aliases = metric_cls.get_predefined_model_params_aliases( + task + ) + param_name = aliases.get("per_class_metrics") + if param_name is not None: + metric_params[param_name] = self._per_class_metrics + applied_per_class_override = True - if self._metrics and not applied_per_class_override: - logger.warning( - "Ignoring `per_class_metrics` for predefined model metrics " - f"{self._metrics} because none of them support a per-class " - "override." - ) + metrics.append( + MetricModuleConfig( + name=metric, + params=metric_params, + is_main_metric=metric == self._main_metric, + ) + ) + + if self._metrics and not applied_per_class_override: + logger.warning( + "Ignoring `per_class_metrics` for predefined model metrics " + f"{self._metrics} because none of them support a per-class " + "override." + ) + return metrics # noqa: TRY300 - return metrics + except Exception: + logger.warning("Unable to ") + + return [ + MetricModuleConfig( + name=metric, + params=self._metrics_params, + is_main_metric=metric == self._main_metric, + ) + for metric in self._metrics + ]