Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 43 additions & 37 deletions luxonis_train/config/predefined_models/base_predefined_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from contextlib import suppress
from typing import Literal

from loguru import logger
Expand Down Expand Up @@ -208,42 +209,47 @@ def nodes(self) -> list[NodeConfig]:
return nodes

def _generate_metrics(self) -> list[MetricModuleConfig]:
if self._per_class_metrics is None:
return [
MetricModuleConfig(
name=metric,
params=self._metrics_params,
is_main_metric=metric == self._main_metric,
)
for metric in self._metrics
]

task = NODES.get(self._head).task
metrics = []
applied_per_class_override = False

for metric in self._metrics:
metric_params = dict(self._metrics_params)
metric_cls = METRICS.get(metric)
aliases = metric_cls.get_predefined_model_params_aliases(task)
param_name = aliases.get("per_class_metrics")
if param_name is not None:
metric_params[param_name] = self._per_class_metrics
applied_per_class_override = True

metrics.append(
MetricModuleConfig(
name=metric,
params=metric_params,
is_main_metric=metric == self._main_metric,
)
)
if self._per_class_metrics is not None:
try:
task = NODES.get(self._head).task
metrics = []
applied_per_class_override = False

for metric in self._metrics:
metric_params = dict(self._metrics_params)
metric_cls = METRICS.get(metric)
aliases = metric_cls.get_predefined_model_params_aliases(
task
)
param_name = aliases.get("per_class_metrics")
if param_name is not None:
metric_params[param_name] = self._per_class_metrics
applied_per_class_override = True

if self._metrics and not applied_per_class_override:
logger.warning(
"Ignoring `per_class_metrics` for predefined model metrics "
f"{self._metrics} because none of them support a per-class "
"override."
)
metrics.append(
MetricModuleConfig(
name=metric,
params=metric_params,
is_main_metric=metric == self._main_metric,
)
)

if self._metrics and not applied_per_class_override:
logger.warning(
"Ignoring `per_class_metrics` for predefined model metrics "
f"{self._metrics} because none of them support a per-class "
"override."
)
return metrics # noqa: TRY300

return metrics
except Exception:
logger.warning("Unable to ")

return [
MetricModuleConfig(
name=metric,
params=self._metrics_params,
is_main_metric=metric == self._main_metric,
)
for metric in self._metrics
]
Loading