diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index ea1c77192..4043c3649 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -68,7 +68,7 @@ jobs:
cache: pip
- name: Install dependencies
- run: pip install -e .[dev]
+ run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu130
- name: Install dev version of LuxonisML
if: startsWith(github.head_ref, 'release/') == false
@@ -147,7 +147,7 @@ jobs:
cache: pip
- name: Install dependencies
- run: pip install -e .[dev]
+ run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu130
- name: Install dev version of LuxonisML
if: startsWith(github.head_ref, 'release/') == false
diff --git a/README.md b/README.md
index 6aff1ad9d..948e939d2 100644
--- a/README.md
+++ b/README.md
@@ -111,6 +111,15 @@ pip install luxonis-train
This will also install the `luxonis_train` CLI. For more information on how to use it, see [CLI Usage](#cli).
+### AIMET Quantization Support
+
+To enable support for AIMET quantization, install the `luxonis-train[aimet]` extra:
+
+```bash
+pip install luxonis-train[aimet] --extra-index-url https://download.pytorch.org/whl/cu130
+
+```
+
## 📝 Usage
@@ -135,6 +144,7 @@ The CLI is the most straightforward way how to use `LuxonisTrain`. The CLI provi
- `tune` - Tune the hyperparameters of the model for better performance
- `inspect` - Inspect the dataset you are using and visualize the annotations
- `annotate` - Annotate a directory using the model’s predictions and generate a new LDF.
+- `quantize` - Quantize the model using `AIMET` quantization techniques
**To get help on any command:**
diff --git a/configs/README.md b/configs/README.md
index ea44c1ecd..6e3e42021 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -510,6 +510,7 @@ Here you can define configuration for exporting.
| `onnx` | `dict` | `{}` | Options specific for ONNX export. See [ONNX](#onnx) section for details |
| `hubai` | `dict` | `{}` | Options for HubAI SDK conversion. See [HubAI](#hubai) section for details |
| `blobconverter` | `dict` | `{}` | Options for converting to BLOB format (deprecated). See [Blob](#blob-deprecated) section |
+| `aimet` | `dict` | `{}` | Options for AIMET quantization. See [AIMET](#aimet) |
### `ONNX`
@@ -571,6 +572,41 @@ exporter:
shaves: 8
```
+### `AIMET`
+
+The [AIMET](https://quic.github.io/aimet-pages/releases/latest/index.html) (AI Model Efficiency Toolkit) provides quantization and model export tools.
+
+| Key | Type | Default value | Description |
+| -------------------------- | ------------------------------------------------- | -------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `active` | `bool` | `False` | Whether to use AIMET for quantization and export |
+| `epochs` | `int` | `20` | Number of epochs to use for quantization-aware training |
+| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights |
+| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters |
+| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values |
+| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use |
+| `config` | `dict \| str` | `{}` | Additional configuration for AIMET. Can be a dictionary or a path to a JSON config file. Refer to the [AIMET documentation](https://quic.github.io/aimet-pages/releases/latest/techniques/runtime_config.html) for details on the available options. |
+| `fold_batch_norms` | `bool` | `False` | Whether to fold batch normalization layers before quantization |
+| `cross_layer_equalization` | `bool` | `False` | Whether to perform cross-layer equalization before quantization |
+| `batch_norm_reestimation` | `bool` | `False` | Whether to perform batch norm re-estimation after quantization |
+| `sequential_mse` | `bool` | `False` | Whether to perform sequential MSE optimization. |
+| `optimizer` | `dict` | `{"name": "SGD", "params": {"lr": 1e-5}}` | Optimizer configuration for quantization-aware training. See [Optimizer](#optimizer) section for details and examples. |
+| `scheduler` | `dict` | `{"name": "StepLR", "params": {"step_size": 5, "gamma": 0.1}}` | Scheduler configuration for quantization-aware training. See [Scheduler](#scheduler) section for details and examples.. |
+| `adaround` | `dict` | `{}` | Configuration for Adaround weight rounding. See [Adaround](#adaround) for more details. |
+
+#### Adaround
+
+Adaptive rounding (AdaRound) is a rounding mechanism for model weights designed to adapt to the data to improve the accuracy of the quantized model.
+
+By default, AIMET uses nearest rounding for quantization, in which weight values are quantized to the nearest integer value. AdaRound, however, uses training data to determine how to round quantized weights. This technique often improves the accuracy of the quantized model.
+
+| Key | Type | Default value | Description |
+| ------------------------ | ----------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `active` | `bool` | `False` | Whether to use AdaRound for weight rounding during quantization |
+| `default_num_iterations` | `int \| None` | `None` | Number of iterations for the AdaRound optimization. The default value is 10K for models with 8- or higher bit weights, and 15K for models with lower than 8 bit weights. |
+| `default_reg_param` | `float` | `0.01` | Regularization parameter, trading off between rounding loss vs reconstruction loss. |
+| `default_beta_range` | `tuple[int, int]` | `(20, 2)` | Start and stop beta parameter for annealing of rounding loss (start_beta, end_beta). |
+| `default_warm_start` | `float` | `0.2` | The warm up period, during which rounding loss has zero effect. |
+
## Tuner
Here you can specify options for tuning.
diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py
index 21de30ebb..271917bfa 100644
--- a/luxonis_train/__main__.py
+++ b/luxonis_train/__main__.py
@@ -154,6 +154,8 @@ def get_visualization_item(
return np_images, np_labels
images, labels = loader[idx]
+ if not isinstance(images, dict):
+ images = {loader.image_source: images}
return (
{
name: image.numpy().transpose(1, 2, 0)
@@ -480,6 +482,29 @@ def convert(
).convert(save_dir=save_dir, weights=weights)
+@app.command(group=export_group, sort_key=1)
+def quantize(
+ opts: list[str] | None = None,
+ /,
+ *,
+ config: str | None = None,
+ weights: str | None = None,
+):
+ """Quantize the model using AIMET.
+
+ @type config: str
+ @param config: Path to the configuration file.
+ @type weights: str
+ @param weights: Path to the model weights.
+ @type opts: list[str]
+ @param opts: A list of optional CLI overrides of the config file.
+ """
+ model = create_model(
+ config, opts, weights=weights, allow_empty_dataset=True
+ )
+ model.quantize()
+
+
@upgrade_app.command()
def config(
config: Annotated[
diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
index e72ace805..fdcb7cee7 100644
--- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
+++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py
@@ -29,6 +29,7 @@ class AdaptiveDetectionLoss(BaseLoss):
n_anchors_list: list[int]
stride_tensor: Tensor
gt_bboxes_scale: Tensor
+ anchor_points_strided: Tensor
def __init__(
self,
@@ -102,6 +103,19 @@ def __init__(
self.class_loss_weight = class_loss_weight
self.iou_loss_weight = iou_loss_weight
+ self.register_buffer(
+ "gt_bboxes_scale",
+ torch.tensor(
+ [
+ self.original_img_size[1],
+ self.original_img_size[0],
+ self.original_img_size[1],
+ self.original_img_size[0],
+ ],
+ ),
+ persistent=False,
+ )
+
self._logged_assigner_change = False
def forward(
@@ -163,21 +177,12 @@ def forward(
return loss, sub_losses
def _init_parameters(self, features: list[Tensor]) -> None:
- if not hasattr(self, "gt_bboxes_scale"):
- self.gt_bboxes_scale = torch.tensor(
- [
- self.original_img_size[1],
- self.original_img_size[0],
- self.original_img_size[1],
- self.original_img_size[0],
- ],
- device=features[0].device,
- )
+ if not hasattr(self, "anchors"):
(
- self.anchors,
- self.anchor_points,
- self.n_anchors_list,
- self.stride_tensor,
+ anchors,
+ anchor_points,
+ n_anchors_list,
+ stride_tensor,
) = anchors_for_fpn_features(
features,
self.stride,
@@ -185,8 +190,22 @@ def _init_parameters(self, features: list[Tensor]) -> None:
self.grid_cell_offset,
multiply_with_stride=True,
)
- self.anchor_points_strided = (
- self.anchor_points / self.stride_tensor
+ self.register_buffer("anchors", anchors, persistent=False)
+ self.register_buffer(
+ "anchor_points", anchor_points, persistent=False
+ )
+ self.register_buffer(
+ "n_anchors_list",
+ torch.tensor(n_anchors_list),
+ persistent=False,
+ )
+ self.register_buffer(
+ "stride_tensor", stride_tensor, persistent=False
+ )
+ self.register_buffer(
+ "anchor_points_strided",
+ anchor_points / stride_tensor,
+ persistent=False,
)
def _run_assigner(
diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py
index d535a4af2..dfbdea71f 100644
--- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py
+++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py
@@ -17,8 +17,6 @@
from luxonis_train.utils.boundingbox import IoUType
from luxonis_train.utils.keypoints import insert_class
-from .bce_with_logits import BCEWithLogitsLoss
-
class EfficientKeypointBBoxLoss(AdaptiveDetectionLoss):
node: EfficientKeypointBBoxHead
@@ -74,9 +72,7 @@ def __init__(
**kwargs,
)
- self.b_cross_entropy = BCEWithLogitsLoss(
- pos_weight=torch.tensor([viz_pw])
- )
+ self.pos_weight = torch.tensor([viz_pw])
self.sigmas = get_sigmas(
sigmas=sigmas, n_keypoints=self.n_keypoints, caller_name=self.name
)
@@ -85,6 +81,13 @@ def __init__(
)
self.regr_kpts_loss_weight = regr_kpts_loss_weight
self.vis_kpts_loss_weight = vis_kpts_loss_weight
+ self.register_buffer(
+ "gt_kpts_scale",
+ torch.tensor(
+ [self.original_img_size[1], self.original_img_size[0]],
+ ),
+ persistent=False,
+ )
def forward(
self,
@@ -95,14 +98,14 @@ def forward(
target_boundingbox: Tensor,
target_keypoints: Tensor,
) -> tuple[Tensor, dict[str, Tensor]]:
+ self._init_parameters(features)
+
device = keypoints_raw.device
target_keypoints = insert_class(target_keypoints, target_boundingbox)
batch_size = class_scores.shape[0]
n_kpts = (target_keypoints.shape[1] - 2) // 3
- self._init_parameters(features)
-
pred_bboxes = dist2bbox(distributions, self.anchor_points_strided)
keypoints_raw = self.dist2kpts_noscale(
self.anchor_points_strided,
@@ -124,7 +127,7 @@ def forward(
scaled_raw_keypoints = keypoints_raw.clone()
scaled_raw_keypoints[..., :2] = scaled_raw_keypoints[
..., :2
- ] * self.stride_tensor.view(1, -1, 1, 1)
+ ] * self.stride_tensor.clone().view(1, -1, 1, 1)
sigmas = self.sigmas.to(device)
@@ -190,8 +193,11 @@ def forward(
regression_loss = (
((1 - torch.exp(-e)) * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9)
).mean()
- visibility_loss = self.b_cross_entropy.forward(
- keypoints_raw[..., 2], mask
+
+ visibility_loss = F.binary_cross_entropy_with_logits(
+ keypoints_raw[..., 2],
+ mask,
+ pos_weight=self.pos_weight.clone().to(device),
)
one_hot_label = F.one_hot(assigned_labels.long(), self.n_classes + 1)[
@@ -264,12 +270,3 @@ def dist2kpts_noscale(self, anchor_points: Tensor, kpts: Tensor) -> Tensor:
adj_kpts[..., 0] += x_adj
adj_kpts[..., 1] += y_adj
return adj_kpts
-
- def _init_parameters(self, features: list[Tensor]) -> None:
- if hasattr(self, "gt_kpts_scale"):
- return
- super()._init_parameters(features)
- self.gt_kpts_scale = torch.tensor(
- [self.original_img_size[1], self.original_img_size[0]],
- device=features[0].device,
- )
diff --git a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py
index f71cd7629..8eb5c1ab4 100644
--- a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py
+++ b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py
@@ -92,7 +92,7 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.dtype == img1.dtype:
- window = self.window.to(device)
+ window = self.window.to(device).clone()
else:
window = (
create_window(self.window_size, channel)
diff --git a/luxonis_train/attached_modules/metrics/base_metric.py b/luxonis_train/attached_modules/metrics/base_metric.py
index 8fe67eac9..d0e5cec2c 100644
--- a/luxonis_train/attached_modules/metrics/base_metric.py
+++ b/luxonis_train/attached_modules/metrics/base_metric.py
@@ -164,6 +164,12 @@ def compute(
"""
return super().compute()
+ def __eq__(self, other: object) -> bool:
+ return self is other
+
+ def __hash__(self) -> int:
+ return id(self)
+
@cached_property
def _signature(self) -> dict[str, Parameter]:
return get_signature(self.update)
diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py
index 0f4d774eb..35246dfc1 100644
--- a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py
+++ b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py
@@ -51,6 +51,9 @@ def compute(self) -> dict[str, Tensor]:
}
def _update(self, predictions: list[Tensor], targets: Tensor) -> None:
+ if self.confusion_matrix.is_inference():
+ self.confusion_matrix = self.confusion_matrix.clone()
+
for pred, target in zip(
predictions,
instances_from_batch(targets, batch_size=len(predictions)),
diff --git a/luxonis_train/attached_modules/visualizers/base_visualizer.py b/luxonis_train/attached_modules/visualizers/base_visualizer.py
index ed7f9cc85..565c1e934 100644
--- a/luxonis_train/attached_modules/visualizers/base_visualizer.py
+++ b/luxonis_train/attached_modules/visualizers/base_visualizer.py
@@ -3,8 +3,9 @@
from inspect import Parameter
import torch.nn.functional as F
+from luxonis_ml.data.utils import ColorMap
from torch import Tensor
-from typing_extensions import TypeVarTuple, Unpack
+from typing_extensions import TypeVarTuple, Unpack, override
from luxonis_train.attached_modules import BaseAttachedModule
from luxonis_train.registry import VISUALIZERS
@@ -25,6 +26,13 @@ def __init__(self, *args, scale: float = 1.0, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.scale = scale
+ @override
+ def __getstate__(self) -> dict:
+ state = super().__getstate__()
+ if "colormap" in state:
+ del state["colormap"]
+ return state
+
@staticmethod
def scale_canvas(canvas: Tensor, scale: float = 1.0) -> Tensor:
return F.interpolate(
@@ -34,6 +42,10 @@ def scale_canvas(canvas: Tensor, scale: float = 1.0) -> Tensor:
align_corners=False,
)
+ @cached_property
+ def colormap(self) -> ColorMap:
+ return ColorMap()
+
@abstractmethod
def forward(
self,
diff --git a/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py
index 772c632d3..4d3015775 100644
--- a/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py
+++ b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py
@@ -3,7 +3,6 @@
import numpy as np
import seaborn as sns
from loguru import logger
-from luxonis_ml.data.utils import ColorMap
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from torch import Tensor
@@ -25,11 +24,10 @@ def __init__(self, z_score_threshold: float = 3, **kwargs):
outliers.
"""
super().__init__(**kwargs)
- self.colors = ColorMap()
self.z_score_threshold = z_score_threshold
def _get_color(self, label: int) -> tuple[float, float, float]:
- r, g, b = self.colors[label]
+ r, g, b = self.colormap[label]
return r / 255, g / 255, b / 255
def forward(
@@ -53,6 +51,7 @@ def forward(
@return: An embedding space projection.
"""
embeddings_np = predictions.detach().cpu().numpy()
+ embeddings_np[np.isnan(embeddings_np) | np.isinf(embeddings_np)] = 0.0
ids_np = target.detach().cpu().numpy().astype(int)
pca = PCA(n_components=2, random_state=42)
@@ -92,8 +91,9 @@ def plot_to_tensor(
plot_func: Callable[[plt.Axes, np.ndarray, np.ndarray], None],
) -> Tensor:
fig, ax = plt.subplots(figsize=(10, 10))
- ax.set_xlim(embeddings_2d[:, 0].min(), embeddings_2d[:, 0].max())
- ax.set_ylim(embeddings_2d[:, 1].min(), embeddings_2d[:, 1].max())
+ if embeddings_2d.size > 0:
+ ax.set_xlim(embeddings_2d[:, 0].min(), embeddings_2d[:, 0].max())
+ ax.set_ylim(embeddings_2d[:, 1].min(), embeddings_2d[:, 1].max())
plot_func(ax, embeddings_2d, ids_np)
ax.axis("off")
diff --git a/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py b/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py
index 6a457b22b..92eda5d1f 100644
--- a/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py
+++ b/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py
@@ -2,7 +2,6 @@
import torch
from loguru import logger
-from luxonis_ml.data.utils.visualizations import ColorMap
from torch import Tensor
from typing_extensions import override
@@ -47,7 +46,6 @@ def __init__(
self.background_class = background_class
self.background_color = background_color
self.alpha = alpha
- self.colormap = ColorMap()
self._warn_colors = True
diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md
index 92a9b1272..c45d59b57 100644
--- a/luxonis_train/callbacks/README.md
+++ b/luxonis_train/callbacks/README.md
@@ -186,6 +186,18 @@ A callback that maintains an exponential moving average (EMA) of the model's par
| `use_dynamic_decay` | `bool` | `True` | If enabled, adjusts the decay factor dynamically based on the training iteration. |
| `decay_tau` | `float` | `2000` | The time constant (tau) for dynamic decay, influencing how quickly the EMA adapts. |
+## `AIMETCallback`
+
+Callback to perform AIMET quantization at the end of the training.
+
+This callback runs AIMET post-training static quantization using the best checkpoint (by default based on the main metric) at the end of training.
+
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| ---------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `preferred_checkpoint` | `Literal["metric", "loss"]` | `"metric"` | Which checkpoint should the callback use. If the preferred checkpoint is not available, the other option is used. If none is available, the callback is skipped |
+
## `TrainingProgressCallback`
Callback that publishes training progress and timing metrics.
@@ -198,9 +210,8 @@ Callback that publishes training progress and timing metrics.
**Published Metrics:**
-| Metric Key | Description |
-| ------------------------------ | ------------------------------------------------------- |
-| `train/epoch_progress_percent` | Percentage (0-100) of current epoch completed |
-| `train/epoch_duration_sec` | Time elapsed so far in current epoch |
-| `train/epoch_completion_sec` | Total duration of completed training epoch in seconds |
-| `val/epoch_completion_sec` | Total duration of completed validation epoch in seconds |
+| Metric Key | Description |
+| ------------------------------ | --------------------------------------------- |
+| `train/epoch_progress_percent` | Percentage (0-100) of current epoch completed |
+| `train/epoch_duration_sec` | Time elapsed so far in current epoch |
+| `train/epoch_completion_sec` | Total duration of completed epoch in seconds |
diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py
index 3b29218fc..3fbde3ff7 100644
--- a/luxonis_train/callbacks/__init__.py
+++ b/luxonis_train/callbacks/__init__.py
@@ -11,6 +11,7 @@
from luxonis_train.registry import CALLBACKS
+from .aimet_callback import AIMETCallback
from .archive_on_train_end import ArchiveOnTrainEnd
from .convert_on_train_end import ConvertOnTrainEnd
from .ema import EMACallback
@@ -50,6 +51,7 @@
__all__ = [
+ "AIMETCallback",
"ArchiveOnTrainEnd",
"BaseLuxonisProgressBar",
"ConvertOnTrainEnd",
diff --git a/luxonis_train/callbacks/aimet_callback.py b/luxonis_train/callbacks/aimet_callback.py
new file mode 100644
index 000000000..99d59bd11
--- /dev/null
+++ b/luxonis_train/callbacks/aimet_callback.py
@@ -0,0 +1,16 @@
+import lightning.pytorch as pl
+
+import luxonis_train as lxt
+from luxonis_train.callbacks.needs_checkpoint import NeedsCheckpoint
+from luxonis_train.registry import CALLBACKS
+
+
+@CALLBACKS.register()
+class AIMETCallback(NeedsCheckpoint):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def on_train_end(
+ self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule"
+ ) -> None:
+ pl_module.core.quantize(self.get_checkpoint(pl_module))
diff --git a/luxonis_train/callbacks/gradcam_visualizer.py b/luxonis_train/callbacks/gradcam_visualizer.py
index ed26b08ed..0f23ff77c 100644
--- a/luxonis_train/callbacks/gradcam_visualizer.py
+++ b/luxonis_train/callbacks/gradcam_visualizer.py
@@ -47,7 +47,7 @@ def forward(self, inputs: Tensor, *args, **kwargs) -> Tensor:
@return: The processed output based on the task type.
"""
input_dict = {"image": inputs}
- output = self.pl_module(input_dict, *args, **kwargs)
+ output = self.pl_module.full_forward(input_dict, *args, **kwargs)
if len(output.outputs) > 1:
logger.warning(
"Model has multiple heads. Using the first head for Grad-CAM."
@@ -142,7 +142,9 @@ def on_validation_batch_end(
@param batch_idx: The index of the batch.
"""
if batch_idx < self.log_n_batches:
- images = batch[0][pl_module.image_source]
+ images = batch[0]
+ if isinstance(images, dict):
+ images = images[pl_module.image_source]
self.visualize_gradients(trainer, pl_module, images, batch_idx)
def visualize_gradients(
diff --git a/luxonis_train/callbacks/luxonis_progress_bar.py b/luxonis_train/callbacks/luxonis_progress_bar.py
index fa05e8f0a..731e6e49c 100644
--- a/luxonis_train/callbacks/luxonis_progress_bar.py
+++ b/luxonis_train/callbacks/luxonis_progress_bar.py
@@ -1,6 +1,6 @@
import time
from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Iterable, Mapping
from io import StringIO
from typing import Any
@@ -59,6 +59,24 @@ def print_results(
"""
...
+ @abstractmethod
+ def print_table(
+ self,
+ title: str,
+ table: Iterable[tuple[str | int | float, ...]],
+ column_names: list[str],
+ ) -> None:
+ """Prints table to the console.
+
+ @type title: str
+ @param title: Title of the table
+ @type table: Iterable[tuple[str | int | float, ...]]
+ @param table: Table to print as an iterable of rows, where each row is a tuple of values.
+ @type column_names: list[str]
+ @param column_names: Names of the columns in the table
+ """
+ ...
+
def _log_progress(self, trainer: pl.Trainer) -> None:
duration = (
time.time() - self._epoch_start_time
@@ -129,7 +147,9 @@ def print_results(
logger.info(f"Loss: {loss}")
logger.info("Metrics:")
for table_name, table in metrics.items():
- self._print_table(table_name, table)
+ self.print_table(
+ table_name, list(table.items()), ["Name", "Value"]
+ )
for matrix_name, matrix in matrices.get(table_name, {}).items():
self._print_matrix(
self._format_matrix_title(matrix_name), matrix
@@ -150,29 +170,26 @@ def _rule(self, title: str | None = None) -> None:
else:
logger.info("-----------------")
- def _print_table(
+ @override
+ def print_table(
self,
title: str,
- table: Mapping[str, int | str | float],
- key_name: str = "Name",
- value_name: str = "Value",
+ table: Iterable[tuple[str | int | float, ...]],
+ column_names: list[str],
) -> None:
"""Prints table to the console using tabulate.
@type title: str
@param title: Title of the table
- @type table: Mapping[str, int | str | float]
- @param table: Table to print
- @type key_name: str
- @param key_name: Name of the key column. Defaults to C{"Name"}.
- @type value_name: str
- @param value_name: Name of the value column. Defaults to
- C{"Value"}.
+ @type table: Iterable[tuple[str | int | float, ...]]
+ @param table: Table to print as an iterable of rows, where each row is a tuple of values.
+ @type column_names: list[str]
+ @param column_names: Names of the columns in the table
"""
self._rule(title)
formatted = tabulate(
- table.items(),
- headers=[key_name, value_name],
+ table,
+ headers=column_names,
tablefmt="fancy_grid",
numalign="right",
)
@@ -250,7 +267,9 @@ def print_results(
)
self.console.print("[bold magenta]Metrics:[/bold magenta]")
for table_name, table in metrics.items():
- self._print_table(table_name, table)
+ self.print_table(
+ table_name, list(table.items()), ["Name", "Value"]
+ )
for matrix_name, matrix in matrices.get(table_name, {}).items():
self._print_matrix(
self._format_matrix_title(matrix_name), matrix
@@ -270,7 +289,12 @@ def print_results(
self._log_console.print(f"Loss: {loss}")
self._log_console.print("Metrics:")
for table_name, table in metrics.items():
- self._print_table(table_name, table, console=self._log_console)
+ self.print_table(
+ table_name,
+ list(table.items()),
+ ["Name", "Value"],
+ console=self._log_console,
+ )
for matrix_name, matrix in matrices.get(table_name, {}).items():
self._print_matrix(
self._format_matrix_title(matrix_name),
@@ -293,25 +317,22 @@ def print_results(
self._log_buffer.seek(0)
self._log_buffer.truncate(0)
- def _print_table(
+ @override
+ def print_table(
self,
title: str,
- table: Mapping[str, int | str | float],
- key_name: str = "Name",
- value_name: str = "Value",
+ table: Iterable[tuple[str | int | float, ...]],
+ column_names: list[str],
console: Console | None = None,
) -> None:
"""Prints table to the console using rich text.
@type title: str
@param title: Title of the table
- @type table: Mapping[str, int | str | float]
- @param table: Table to print
- @type key_name: str
- @param key_name: Name of the key column. Defaults to C{"Name"}.
- @type value_name: str
- @param value_name: Name of the value column. Defaults to
- C{"Value"}.
+ @type table: Iterable[tuple[str | int | float, ...]]
+ @param table: Table to print as an iterable of rows, where each row is a tuple of values.
+ @type column_names: list[str]
+ @param column_names: Names of the columns in the table
@param console: Console instance to use, if None use default
console. Defaults to None.
@type console: Console | None
@@ -323,10 +344,18 @@ def _print_table(
header_style="bold magenta",
title_style="bold",
)
- rich_table.add_column(key_name, style="magenta")
- rich_table.add_column(value_name, style="white")
- for name, value in table.items():
- rich_table.add_row(name, f"{value:.5f}")
+ for i, column_name in enumerate(column_names):
+ rich_table.add_column(
+ column_name, style="magenta" if i == 0 else "white"
+ )
+ for name, *values in table:
+ rich_table.add_row(
+ str(name),
+ *[
+ f"{value:.5f}" if isinstance(value, float) else str(value)
+ for value in values
+ ],
+ )
console.print(rich_table)
def _print_matrix(
diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py
index 70f9e8eac..094242ebc 100644
--- a/luxonis_train/config/config.py
+++ b/luxonis_train/config/config.py
@@ -1,6 +1,8 @@
+import json
import sys
from collections.abc import Mapping
from contextlib import suppress
+from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Literal, NamedTuple
@@ -725,6 +727,87 @@ def _validate_quantization_mode(value: str) -> str:
return value
+class AdaroundConfig(BaseModelExtraForbid):
+ active: bool = False
+ default_num_iterations: PositiveInt | None = None
+ default_reg_param: float = 0.01
+ default_beta_range: tuple[int, int] = (20, 2)
+ default_warm_start: float = 0.2
+
+
+class AIMETConfig(BaseModelExtraForbid):
+ active: bool = False
+
+ default_output_bw: Literal[4, 8, 16] = 8
+ default_param_bw: Literal[4, 8, 16] = 8
+ default_data_type: Literal["int", "float"] = "int"
+ quant_scheme: Literal["min_max", "tf", "tf_enhanced"] = "min_max"
+ config: Params | None = None
+
+ fold_batch_norms: bool = False
+ cross_layer_equalization: bool = False
+ batch_norm_reestimation: bool = False
+ sequential_mse: bool = False
+ adaround: AdaroundConfig = Field(default_factory=AdaroundConfig)
+
+ epochs: PositiveInt = 20
+ optimizer: ConfigItem = Field(
+ default_factory=lambda: ConfigItem(name="SGD", params={"lr": 1e-5})
+ )
+ scheduler: ConfigItem = Field(
+ default_factory=lambda: ConfigItem(
+ name="StepLR", params={"step_size": 5, "gamma": 0.1}
+ )
+ )
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_active(cls, data: Params) -> Params:
+ if not data.get("active", False):
+ return data
+ for required_field in [
+ "fold_batch_norms",
+ "cross_layer_equalization",
+ "batch_norm_reestimation",
+ "sequential_mse",
+ ]:
+ if required_field not in data:
+ raise ValueError(
+ f"AIMET config is active but missing required field '{required_field}'."
+ )
+ adaround = data.get("adaround", {})
+ if not isinstance(adaround, dict):
+ raise TypeError(
+ f"Invalid type for 'adaround': {type(adaround)}. "
+ "Expected a dict."
+ )
+ if "active" not in adaround:
+ raise ValueError(
+ "AIMET config is active but missing required field "
+ "'adaround.active'."
+ )
+ return data
+
+ @field_validator("config", mode="before")
+ @classmethod
+ def validate_config(cls, value: ParamValue) -> Any:
+ if isinstance(value, str):
+ try:
+ fs = LuxonisFileSystem(value)
+ return json.loads(fs.read_text(""))
+ except Exception as e:
+ raise ValueError(
+ f"Failed to load AIMET config from file '{value}': {e}"
+ ) from e
+ return value
+
+ @field_serializer("default_data_type", "quant_scheme")
+ def serialize_enums(self, value: Any) -> str:
+ if isinstance(value, Enum):
+ return value.name
+ return value
+
+
class ExportConfig(ArchiveConfig):
name: str | None = None
input_shape: list[int] | None = None
@@ -741,6 +824,7 @@ class ExportConfig(ArchiveConfig):
default_factory=BlobconverterExportConfig
)
hubai: HubAIExportConfig = Field(default_factory=HubAIExportConfig)
+ aimet: AIMETConfig = Field(default_factory=AIMETConfig)
@field_validator("scale_values", "mean_values", mode="before")
@classmethod
diff --git a/luxonis_train/config/predefined_models/base_predefined_model.py b/luxonis_train/config/predefined_models/base_predefined_model.py
index f8a820e73..7e90d2e21 100644
--- a/luxonis_train/config/predefined_models/base_predefined_model.py
+++ b/luxonis_train/config/predefined_models/base_predefined_model.py
@@ -116,7 +116,7 @@ def __init__(
self._metrics = (
[metrics] if isinstance(metrics, str) else metrics or []
)
- if main_metric is None:
+ if main_metric is None and self._metrics:
if len(self._metrics) == 1:
main_metric = self._metrics[0]
else:
diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py
index 93d4270df..5fb4ea64b 100644
--- a/luxonis_train/core/core.py
+++ b/luxonis_train/core/core.py
@@ -1,9 +1,11 @@
+import json
import tempfile
import threading
from collections.abc import Mapping
+from copy import deepcopy
from pathlib import Path
-from threading import ExceptHookArgs
-from typing import Any, Literal, overload
+from threading import ExceptHookArgs, Thread
+from typing import Any, Literal, cast, overload
import lightning.pytorch as pl
import lightning_utilities.core.rank_zero as rank_zero_module
@@ -20,6 +22,9 @@
from luxonis_ml.nn_archive.config import CONFIG_VERSION
from luxonis_ml.typing import Params, PathType
from luxonis_ml.utils import Environ, LuxonisFileSystem
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from torch.utils.data.dataloader import DataLoader
from typeguard import typechecked
from luxonis_train.callbacks import (
@@ -36,7 +41,13 @@
DummyLoader,
LuxonisLoaderTorch,
)
-from luxonis_train.registry import LOADERS
+from luxonis_train.loaders.base_loader import LuxonisLoaderTorchOutput
+from luxonis_train.registry import (
+ LOADERS,
+ OPTIMIZERS,
+ SCHEDULERS,
+ from_registry,
+)
from luxonis_train.typing import View
from luxonis_train.utils import (
DatasetMetadata,
@@ -81,8 +92,8 @@ def __init__(
cfg: PathType | Params | Config | None,
opts: Params | list[str] | tuple[str, ...] | None = None,
*,
- allow_empty_dataset: bool = False,
weights: PathType | dict[str, Any] | None = None,
+ allow_empty_dataset: bool = False,
dataset_metadata: DatasetMetadata | None = None,
):
"""Constructs a new Core instance.
@@ -266,7 +277,10 @@ def __init__(
"Weighted sampler is not implemented yet."
)
- self.pytorch_loaders: dict[View, torch_data.DataLoader] = {}
+ self.pytorch_loaders: dict[
+ View,
+ torch_data.DataLoader[LuxonisLoaderTorchOutput],
+ ] = {}
for view in ("train", "val", "test"):
if self.cfg.trainer.n_validation_batches is not None and view in {
"val",
@@ -348,8 +362,30 @@ def __init__(
_core=self,
)
+ if weights is not None:
+ weights = LuxonisFileSystem.download(
+ str(weights), self.run_save_dir
+ )
+ if self.cfg.model.weights is not None:
+ logger.warning(
+ "Weights provided in the command line, but config weights are set. "
+ "Ignoring weights provided in config."
+ )
+ self.lightning_module.load_checkpoint(weights)
self._exported_models: dict[str, Path] = {}
+ @property
+ def train_loader(self) -> DataLoader:
+ return self.pytorch_loaders["train"]
+
+ @property
+ def val_loader(self) -> DataLoader:
+ return self.pytorch_loaders["val"]
+
+ @property
+ def test_loader(self) -> DataLoader:
+ return self.pytorch_loaders["test"]
+
def save_checkpoint(
self,
path: PathType,
@@ -446,8 +482,8 @@ def train(
self._train(
resume_weights,
self.lightning_module,
- self.pytorch_loaders["train"],
- self.pytorch_loaders["val"],
+ self.train_loader,
+ self.val_loader,
)
logger.info("Training finished")
logger.info(f"Checkpoints saved in: {self.run_save_dir}")
@@ -464,8 +500,8 @@ def thread_exception_hook(args: ExceptHookArgs) -> None:
args=(
resume_weights,
self.lightning_module,
- self.pytorch_loaders["train"],
- self.pytorch_loaders["val"],
+ self.train_loader,
+ self.val_loader,
),
daemon=True,
)
@@ -498,6 +534,9 @@ def export(
This is useful for updating the metadata in the checkpoint
file in case they changed (e.g. new configuration file,
architectural changes affecting the exection order etc.)
+
+ @rtype: Path
+ @return: Path to the exported ONNX model file or .ckpt file if ckpt_only is True.
"""
weights = self.resolve_weights(weights)
@@ -630,7 +669,7 @@ def test(
new_thread: Literal[True] = ...,
view: Literal["train", "test", "val"] = "test",
weights: PathType | dict[str, Any] | None = ...,
- ) -> None: ...
+ ) -> Thread: ...
@typechecked
def test(
@@ -638,22 +677,22 @@ def test(
new_thread: bool = False,
view: Literal["train", "val", "test"] = "test",
weights: PathType | dict[str, Any] | None = None,
- ) -> Mapping[str, float] | None:
+ ) -> Mapping[str, float] | Thread:
"""Runs testing.
@type new_thread: bool
@param new_thread: Runs testing in a new thread if set to True.
@type view: Literal["train", "test", "val"]
@param view: Which view to run the testing on. Defauls to "test".
- @rtype: Mapping[str, float] | None
- @return: If new_thread is False, returns a dictionary test
- results.
@type weights: PathType | None
@param weights: Path to the checkpoint from which to load weights.
If not specified, the value of `model.weights` from the
configuration file will be used. The current weights of the
model will be temporarily replaced with the weights from the
specified checkpoint.
+ @rtype: Mapping[str, float] | Thread
+ @return: If new_thread is False, returns a dictionary test
+ results.
"""
weights = self.resolve_weights(weights)
loader = self.pytorch_loaders[view]
@@ -665,7 +704,8 @@ def test(
args=(self.lightning_module, loader),
daemon=True,
)
- return self.thread.start()
+ self.thread.start()
+ return self.thread
return self.pl_trainer.test(self.lightning_module, loader)[0]
def infer(
@@ -897,8 +937,8 @@ def _objective(trial: optuna.trial.Trial) -> float:
try:
pl_trainer.fit(
lightning_module,
- self.pytorch_loaders["train"],
- self.pytorch_loaders["val"],
+ self.train_loader,
+ self.val_loader,
)
pruner_callback.check_pruned()
@@ -1242,6 +1282,252 @@ def convert(
return archive_path, conversion_artifacts
+ def quantize(
+ self,
+ weights: PathType | None = None,
+ epochs: int | None = None,
+ quant_scheme: Literal["min_max", "tf", "tf_enhanced"] | None = None,
+ default_output_bw: int | None = None,
+ default_param_bw: int | None = None,
+ config_file: str | None = None,
+ default_data_type: Literal["int", "float"] | None = None,
+ adaround: bool | None = None,
+ adaround_iterations: int | None = None,
+ adaround_reg_param: float | None = None,
+ adaround_beta_range: tuple[int, int] | None = None,
+ adaround_warm_start: float | None = None,
+ fold_batch_norms: bool | None = None,
+ cross_layer_equalization: bool | None = None,
+ batch_norm_reestimation: bool | None = None,
+ sequential_mse: bool | None = None,
+ optimizer: Optimizer | None = None,
+ scheduler: LRScheduler | None = None,
+ in_place: bool = False,
+ ) -> Path:
+ """Runs quantization of the model using AIMET.
+
+ @type weights: PathType | None
+ @param weights: Path to the checkpoint from which to load
+ weights.
+ @type epochs: int | None
+ @param epochs: Number of epochs to run quantization-aware
+ training for.
+ @type quant_scheme: str | QuantScheme | None
+ @param quant_scheme: Quantization scheme to use. If not
+ specified, the value from the configuration file will be
+ used.
+ @type default_output_bw: int | None
+ @param default_output_bw: Default bitwidth to use for quantizing
+ outputs. If not specified, the value from the configuration
+ file will be used.
+ @type default_param_bw: int | None
+ @param default_param_bw: Default bitwidth to use for quantizing
+ parameters. If not specified, the value from the
+ configuration file will be used.
+ @type config_file: str | None
+ @param config_file: Path to the AIMET configuration file or a
+ dictionary containing the AIMET configuration. If not
+ specified, the value from the configuration file will be
+ used.
+ @type default_data_type: QuantizationDataType | None
+ @param default_data_type: Default data type to use for
+ quantization. If not specified, the value from the
+ configuration file will be used.
+ @type adaround: bool | None
+ @param adaround: Whether to use Adaround for weight
+ quantization. If not specified, the value from the
+ configuration file will be used.
+ @type adaround_iterations: int | None
+ @param adaround_iterations: Number of iterations to run Adaround
+ for. If not specified, the value from the configuration file
+ will be used.
+ @type adaround_reg_param: float | None
+ @param adaround_reg_param: Regularization parameter to use for
+ Adaround. If not specified, the value from the configuration
+ file will be used.
+ @type adaround_beta_range: tuple[int, int] | None
+ @param adaround_beta_range: Beta range to use for Adaround. If
+ not specified, the value from the configuration file will be
+ used.
+ @type adaround_warm_start: float | None = None
+ @param adaround_warm_start: Warm start value to use for
+ Adaround. If not specified, the value from the configuration
+ file will be used.
+ @type fold_batch_norms: bool | None
+ @param fold_batch_norms: Whether to fold batch norms before
+ quantization. If not specified, the value from the
+ configuration file will be used.
+ @type cross_layer_equalization: bool | None
+ @param cross_layer_equalization: Whether to perform cross-layer
+ equalization before quantization. If not specified, the
+ value from the configuration file will be used.
+ @type batch_norm_reestimation: bool | None
+ @param batch_norm_reestimation: Whether to perform batch norm
+ reestimation after folding batch norms. If not specified,
+ the value from the configuration file will be used.
+ @type optimizer: Optimizer | None
+ @param optimizer: Optimizer to use for quantization-aware
+ training. If not specified, the optimizer from the
+ configuration file will be used.
+ @type scheduler: LRScheduler | None
+ @param scheduler: Learning rate scheduler to use for
+ quantization-aware training. If not specified, the scheduler
+ from the configuration file will be used.
+ @type in_place: bool
+ @param in_place: Whether to perform quantization in-place on the
+ original model or to create a copy of the model for
+ quantization. Defaults to False, which means that a copy of
+ the model will be created for quantization. Setting this to
+ True will modify the original model in-place, which may save
+ memory but will overwrite the original model's weights and
+ structure.
+ """
+ from aimet_torch.common.defs import QuantizationDataType, QuantScheme
+
+ from .utils.aimet_utils import (
+ post_training_quantization,
+ quantization_aware_training,
+ )
+
+ save_dir = self.run_save_dir / "aimet"
+ save_dir.mkdir(parents=True, exist_ok=True)
+
+ cfg = self.cfg.exporter.aimet
+
+ aimet_config_file = config_file or cfg.config
+ if isinstance(aimet_config_file, dict):
+ with open(save_dir / "aimet_config.json", "w") as f:
+ json.dump(aimet_config_file, f, indent=4)
+ aimet_config_file = str(save_dir / "aimet_config.json")
+
+ adaround = adaround if adaround is not None else cfg.adaround.active
+ fold_batch_norms = (
+ fold_batch_norms
+ if fold_batch_norms is not None
+ else cfg.fold_batch_norms
+ )
+ cross_layer_equalization = (
+ cross_layer_equalization
+ if cross_layer_equalization is not None
+ else cfg.cross_layer_equalization
+ )
+ batch_norm_reestimation = (
+ batch_norm_reestimation
+ if batch_norm_reestimation is not None
+ else cfg.batch_norm_reestimation
+ )
+ sequential_mse = (
+ sequential_mse
+ if sequential_mse is not None
+ else cfg.sequential_mse
+ )
+
+ if not in_place:
+ model = deepcopy(self.lightning_module)
+ else:
+ model = self.lightning_module
+
+ model.reparametrize().eval()
+
+ if weights is not None:
+ model.load_checkpoint(weights)
+
+ pre_quant_test = self.pl_trainer.test(model, self.val_loader)[0]
+
+ dummy_inputs = {
+ input_name: torch.randn([1, *shape]).to(model.device)
+ for shapes in model.nodes.loader_input_shapes.values()
+ for input_name, shape in shapes.items()
+ }
+
+ if len(dummy_inputs) > 1:
+ raise NotImplementedError(
+ "Quantization is not yet supported for models "
+ "with multiple inputs."
+ )
+ input_names = list(dummy_inputs.keys())
+ output_names = model._get_output_onnx_names(deepcopy(dummy_inputs))
+ dummy_inputs = next(iter(dummy_inputs.values()))
+
+ sim = post_training_quantization(
+ model,
+ dummy_inputs,
+ self.val_loader,
+ save_dir,
+ QuantScheme.from_str(quant_scheme or cfg.quant_scheme),
+ default_output_bw or cfg.default_output_bw,
+ default_param_bw or cfg.default_param_bw,
+ QuantizationDataType[default_data_type or cfg.default_data_type],
+ aimet_config_file,
+ adaround,
+ adaround_iterations
+ if adaround_iterations is not None
+ else cfg.adaround.default_num_iterations,
+ adaround_reg_param
+ if adaround_reg_param is not None
+ else cfg.adaround.default_reg_param,
+ adaround_beta_range or cfg.adaround.default_beta_range,
+ adaround_warm_start
+ if adaround_warm_start is not None
+ else cfg.adaround.default_warm_start,
+ fold_batch_norms,
+ cross_layer_equalization,
+ batch_norm_reestimation,
+ sequential_mse,
+ )
+ model = cast(LuxonisLightningModule, sim.model)
+
+ model.eval()
+ ptq_test = self.pl_trainer.test(model, self.val_loader)[0]
+
+ if optimizer is None:
+ optimizer = from_registry(
+ OPTIMIZERS,
+ cfg.optimizer.name,
+ params=sim.model.parameters(),
+ **cfg.optimizer.params,
+ )
+ if scheduler is None:
+ scheduler = from_registry(
+ SCHEDULERS,
+ cfg.scheduler.name,
+ optimizer=optimizer,
+ **cfg.scheduler.params,
+ )
+
+ model = quantization_aware_training(
+ sim,
+ dummy_inputs,
+ self.train_loader,
+ optimizer,
+ scheduler,
+ epochs if epochs is not None else cfg.epochs,
+ fold_batch_norms,
+ batch_norm_reestimation,
+ ).eval()
+
+ qat_test = self.pl_trainer.test(model, self.val_loader)[0]
+
+ model.set_export_mode(mode=True)
+
+ sim.onnx.export(
+ dummy_inputs,
+ (save_dir / self.cfg.model.name).with_suffix(".onnx"),
+ input_names=input_names,
+ output_names=output_names,
+ )
+
+ table = []
+ for key, value in pre_quant_test.items():
+ log_key = key.replace("test/metric/", "").replace("test/loss/", "")
+ table.append((log_key, value, ptq_test[key], qat_test[key]))
+ model.progress_bar.print_table(
+ "Quantization results",
+ table,
+ ["Name", "Pre-Quant", "PTQ", "QAT"],
+ )
+ return save_dir
+
@property
def environ(self) -> Environ:
return self.cfg.ENVIRON
diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py
new file mode 100644
index 000000000..8c55c3634
--- /dev/null
+++ b/luxonis_train/core/utils/aimet_utils.py
@@ -0,0 +1,193 @@
+import math
+from importlib.util import find_spec
+from pathlib import Path
+from typing import Any, cast
+
+from aimet_torch import QuantizationSimModel
+from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters
+from aimet_torch.batch_norm_fold import fold_all_batch_norms
+from aimet_torch.bn_reestimation import reestimate_bn_stats
+from aimet_torch.common.defs import QuantizationDataType, QuantScheme
+from aimet_torch.common.quantsim_config.utils import (
+ get_path_for_per_channel_config,
+)
+from aimet_torch.cross_layer_equalization import equalize_model
+from aimet_torch.seq_mse import apply_seq_mse
+from lightning.pytorch.accelerators import CUDAAccelerator
+from loguru import logger
+from rich.progress import track
+from torch import Tensor, nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from torch.utils.data import DataLoader
+
+from luxonis_train.lightning import LuxonisLightningModule
+from luxonis_train.loaders.base_loader import LuxonisLoaderTorchOutput
+
+
+def check_aimet_available() -> None:
+ if not find_spec("aimet_torch"):
+ raise ImportError(
+ "AIMET library is not installed. Please install "
+ "`luxonis-train` with the `aimet` extra enabled "
+ "(pip install luxonis-train[aimet] --extra-index-url https://download.pytorch.org/whl/cu126)"
+ )
+
+
+def post_training_quantization(
+ model: LuxonisLightningModule,
+ dummy_inputs: Tensor,
+ val_loader: DataLoader,
+ save_dir: Path,
+ quant_scheme: str | QuantScheme = QuantScheme.min_max,
+ default_output_bw: int = 8,
+ default_param_bw: int = 8,
+ default_data_type: QuantizationDataType = QuantizationDataType.int,
+ config_file: str | None = None,
+ adaround: bool = False,
+ adaround_iterations: int | None = None,
+ adaround_reg_param: float = 0.01,
+ adaround_beta_range: tuple[int, int] = (20, 2),
+ adaround_warm_start: float = 0.2,
+ fold_batch_norms: bool = False,
+ cross_layer_equalization: bool = False,
+ batch_norm_reestimation: bool = False,
+ sequential_mse: bool = False,
+) -> QuantizationSimModel:
+
+ def pass_calibration_data(model: nn.Module) -> None:
+ assert len(val_loader) > 0, (
+ "Validation loader must have at least one batch"
+ )
+ for imgs, _ in track(
+ val_loader,
+ description="Computing quantization encodings",
+ total=len(val_loader),
+ ):
+ model.forward(imgs)
+
+ if CUDAAccelerator.is_available():
+ dummy_inputs = dummy_inputs.cuda()
+ model.cuda()
+
+ model.eval()
+
+ if fold_batch_norms and not batch_norm_reestimation:
+ logger.info("Folding batch norms into preceding layers")
+ fold_all_batch_norms(
+ model, input_shapes=dummy_inputs.shape, dummy_input=dummy_inputs
+ )
+ if cross_layer_equalization:
+ logger.info("Applying cross-layer equalization")
+ equalize_model(
+ model, input_shapes=dummy_inputs.shape, dummy_input=dummy_inputs
+ )
+
+ if adaround:
+ ada_params = AdaroundParameters(
+ data_loader=val_loader,
+ num_batches=min(
+ len(val_loader),
+ math.ceil(2000 / val_loader.batch_size), # type: ignore
+ ),
+ default_num_iterations=adaround_iterations, # type: ignore
+ default_reg_param=adaround_reg_param,
+ default_beta_range=adaround_beta_range,
+ default_warm_start=adaround_warm_start,
+ )
+ model = cast(
+ LuxonisLightningModule,
+ Adaround.apply_adaround(
+ model,
+ dummy_inputs,
+ ada_params,
+ path=str(save_dir),
+ filename_prefix="adaround",
+ ),
+ )
+
+ if batch_norm_reestimation and config_file is None:
+ config_file = get_path_for_per_channel_config()
+
+ sim = QuantizationSimModel(
+ model=model,
+ dummy_input=dummy_inputs,
+ quant_scheme=quant_scheme,
+ default_output_bw=default_output_bw,
+ default_param_bw=default_param_bw,
+ config_file=config_file,
+ default_data_type=default_data_type,
+ in_place=True,
+ )
+ if sequential_mse:
+ logger.info("Applying sequential MSE")
+
+ apply_seq_mse(
+ sim,
+ data_loader=val_loader,
+ num_candidates=20,
+ forward_fn=_patched_forward_pass,
+ )
+
+ if adaround:
+ sim.set_and_freeze_param_encodings(
+ str(save_dir / "adaround.encodings")
+ )
+
+ sim.compute_encodings(pass_calibration_data)
+ return sim
+
+
+def quantization_aware_training(
+ sim: QuantizationSimModel,
+ dummy_inputs: Tensor,
+ train_loader: DataLoader,
+ optimizer: Optimizer,
+ scheduler: LRScheduler,
+ epochs: int,
+ fold_batch_norms: bool = False,
+ batch_norm_reestimation: bool = False,
+) -> LuxonisLightningModule:
+
+ model = cast(LuxonisLightningModule, sim.model)
+
+ model.train()
+ if CUDAAccelerator.is_available():
+ model.cuda()
+ model.automatic_optimization = False
+
+ for _ in track(
+ range(epochs),
+ description="Running Quantization-Aware Training",
+ total=epochs,
+ ):
+ for imgs, labels in train_loader:
+ optimizer.zero_grad()
+ loss = model.training_step((imgs, labels))
+ loss.backward()
+ optimizer.step()
+ scheduler.step()
+
+ if batch_norm_reestimation:
+ logger.info("Reestimating batch norm statistics")
+
+ reestimate_bn_stats(
+ model, train_loader, forward_fn=_patched_forward_pass
+ )
+
+ if fold_batch_norms:
+ logger.info("Folding batch norms into preceding layers")
+ fold_all_batch_norms(
+ model,
+ input_shapes=dummy_inputs.shape,
+ dummy_input=dummy_inputs,
+ )
+
+ model.automatic_optimization = True
+ return model
+
+
+def _patched_forward_pass(
+ model: nn.Module, inputs: LuxonisLoaderTorchOutput
+) -> Any:
+ return model(inputs[0])
diff --git a/luxonis_train/core/utils/annotate_utils.py b/luxonis_train/core/utils/annotate_utils.py
index 184688aef..a78748593 100644
--- a/luxonis_train/core/utils/annotate_utils.py
+++ b/luxonis_train/core/utils/annotate_utils.py
@@ -82,7 +82,7 @@ def annotated_dataset_generator(
for imgs, metas in loader:
with torch.no_grad():
- batch_out = lt_module(imgs).outputs
+ batch_out = lt_module.full_forward(imgs).outputs
for head_name, head_output in batch_out.items():
img_paths = [Path(p) for p in metas["/metadata/path"]]
diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py
index d44e341ea..7e8d23441 100644
--- a/luxonis_train/core/utils/infer_utils.py
+++ b/luxonis_train/core/utils/infer_utils.py
@@ -59,8 +59,8 @@ def prepare_and_infer_image(
npy_img = model.loaders["val"].augment_test_image(images)
torch_img = torch.tensor(npy_img).unsqueeze(0).permute(0, 3, 1, 2).float()
- return model.lightning_module.forward(
- {"image": torch_img},
+ return model.lightning_module.full_forward(
+ {model.lightning_module.image_source: torch_img},
images=get_denormalized_images(model.cfg, torch_img),
compute_visualizations=True,
)
diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py
index a14d69ac3..1967eab36 100644
--- a/luxonis_train/lightning/luxonis_lightning.py
+++ b/luxonis_train/lightning/luxonis_lightning.py
@@ -1,5 +1,6 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
+from copy import deepcopy
from pathlib import Path
from typing import Any, Literal, cast
@@ -13,7 +14,7 @@
from semver import Version
from torch import Size, Tensor
from torch.nn.modules.module import _IncompatibleKeys
-from typing_extensions import override
+from typing_extensions import Self, override
import luxonis_train
from luxonis_train.attached_modules.visualizers import (
@@ -23,6 +24,8 @@
from luxonis_train.callbacks import BaseLuxonisProgressBar
from luxonis_train.config import Config
from luxonis_train.nodes import BaseNode
+from luxonis_train.nodes.blocks.reparametrizable import Reparametrizable
+from luxonis_train.registry import _INTERNAL
from luxonis_train.typing import Labels, Packet
from luxonis_train.utils import DatasetMetadata, LuxonisTrackerPL
from luxonis_train.utils.checkpoint import filter_checkpoint_state_dict
@@ -94,7 +97,7 @@ class LuxonisLightningModule(pl.LightningModule):
_trainer: pl.Trainer
logger: LuxonisTrackerPL
- __call__: Callable[..., LuxonisOutput]
+ __call__: Callable[..., tuple[Tensor, ...]]
def __init__(
self,
@@ -191,8 +194,43 @@ def core(self) -> "luxonis_train.core.LuxonisModel":
@override
def forward(
+ self, inputs: dict[str, Tensor] | Tensor
+ ) -> tuple[Tensor, ...]:
+ """Forward pass of the model.
+
+ @type inputs: L{Tensor}
+ @param inputs: Input tensors.
+ @rtype: dict[str, L{Packet}[L{Tensor}]]
+ @return: Output of the model.
+ """
+ outputs = self.full_forward(
+ inputs,
+ compute_loss=False,
+ compute_metrics=False,
+ compute_visualizations=False,
+ ).outputs
+
+ output_order = sorted(
+ [
+ (node_name, output_name, i)
+ for node_name, outs in outputs.items()
+ for output_name, out in outs.items()
+ for i in range(len(out))
+ ]
+ )
+ new_outputs = []
+ for node_name, output_name, i in output_order:
+ node_output = outputs[node_name][output_name]
+ if isinstance(node_output, Tensor):
+ new_outputs.append(node_output)
+ else:
+ new_outputs.append(node_output[i])
+
+ return tuple(new_outputs)
+
+ def full_forward(
self,
- inputs: dict[str, Tensor],
+ inputs: dict[str, Tensor] | Tensor,
labels: Labels | None = None,
images: Tensor | None = None,
*,
@@ -226,6 +264,12 @@ def forward(
@rtype: L{LuxonisOutput}
@return: Output of the model.
"""
+ if isinstance(inputs, Tensor):
+ inputs = {self.image_source: inputs}
+
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
+ if labels is not None:
+ labels = {k: v.to(self.device) for k, v in labels.items()}
losses: dict[
str, dict[str, Tensor | tuple[Tensor, dict[str, Tensor]]]
] = defaultdict(dict)
@@ -252,10 +296,14 @@ def forward(
if compute_loss and node.losses and labels is not None:
for loss_name, loss in node.losses.items():
+ loss.to(self.device)
+ if self.training:
+ loss.train()
losses[node_name][loss_name] = loss.run(outputs, labels)
if compute_metrics and node.metrics and labels is not None:
for metric in node.metrics.values():
+ metric.to(self.device)
metric.run_update(outputs, labels)
if (
@@ -264,6 +312,7 @@ def forward(
and images is not None
):
for viz_name, visualizer in node.visualizers.items():
+ visualizer.to(self.device)
viz = combine_visualizations(
visualizer.run(images, images, outputs, labels),
)
@@ -288,10 +337,24 @@ def forward(
outputs=outputs_dict, losses=losses, visualizations=visualizations
)
- def set_export_mode(self, *, mode: bool) -> None:
+ @override
+ def train(self, mode: bool = True) -> Self:
+ super().train(mode)
+ for node in self.nodes.values():
+ node.train(mode)
+ return self
+
+ def set_export_mode(self, mode: bool) -> Self:
for module in self.modules():
if isinstance(module, BaseNode):
module.set_export_mode(mode=mode)
+ return self
+
+ def reparametrize(self) -> Self:
+ for module in self.modules():
+ if isinstance(module, Reparametrizable):
+ module.reparametrize()
+ return self
def export_onnx(self, save_path: PathType, **kwargs) -> Path:
"""Exports the model to ONNX format.
@@ -304,7 +367,7 @@ def export_onnx(self, save_path: PathType, **kwargs) -> Path:
@rtype: Path
@return: Path to the exported model.
"""
- device_before = self.device
+ device = self.device
self.eval()
self.to("cpu") # move to CPU to support deterministic .to_onnx()
@@ -314,88 +377,13 @@ def export_onnx(self, save_path: PathType, **kwargs) -> Path:
for shapes in self.nodes.loader_input_shapes.values()
for input_name, shape in shapes.items()
}
-
- inputs_deep_clone = {
- k: torch.zeros(elem.shape).to(self.device)
- for k, elem in inputs.items()
- }
-
- inputs_for_onnx = {"inputs": inputs_deep_clone}
+ if "input_names" not in kwargs:
+ kwargs["input_names"] = list(inputs.keys())
self.set_export_mode(mode=True)
- outputs = self.forward(inputs_deep_clone).outputs
- output_order = sorted(
- [
- (node_name, output_name, i)
- for node_name, outs in outputs.items()
- for output_name, out in outs.items()
- for i in range(len(out))
- ]
- )
-
- output_counts = defaultdict(int)
- for node_name, outs in outputs.items():
- output_counts[node_name] = sum(len(out) for out in outs.values())
-
- export_output_names_dict = {}
- for node_name, node in self.nodes.items():
- if node.module.export_output_names is not None:
- if (
- len(node.module.export_output_names)
- != output_counts[node_name]
- ):
- logger.warning(
- f"Number of provided output names for node {node_name} "
- f"({len(node.module.export_output_names)}) does not match "
- f"number of outputs ({output_counts[node_name]}). "
- f"Using default names."
- )
- else:
- export_output_names_dict[node_name] = (
- node.module.export_output_names
- )
-
- output_names = []
- # For cases where export_output_names should be used but
- # output node's output is split into multiple subnodes
- running_i = {}
- for node_name, output_name, i in output_order:
- if node_name in export_output_names_dict:
- running_i[node_name] = (
- running_i.get(node_name, -1) + 1
- ) # if not present default to 0 otherwise add 1
- output_names.append(
- export_output_names_dict[node_name][running_i[node_name]]
- )
- else:
- output_names.append(
- f"{self.nodes[node_name].task_name}/{node_name}/{output_name}/{i}"
- )
-
- old_forward = self.forward
-
- def export_forward(inputs: dict[str, Tensor]) -> tuple[Tensor, ...]:
- old_outputs = old_forward(
- inputs,
- None,
- compute_loss=False,
- compute_metrics=False,
- compute_visualizations=False,
- ).outputs
- outputs = []
- for node_name, output_name, i in output_order:
- node_output = old_outputs[node_name][output_name]
- if isinstance(node_output, Tensor):
- outputs.append(node_output)
- else:
- outputs.append(node_output[i])
- return tuple(outputs)
-
- self.forward = export_forward # type: ignore
+ output_names = self._get_output_onnx_names(deepcopy(inputs))
- if "input_names" not in kwargs:
- kwargs["input_names"] = list(inputs.keys())
if "output_names" not in kwargs:
kwargs["output_names"] = output_names
@@ -403,23 +391,22 @@ def export_forward(inputs: dict[str, Tensor]) -> tuple[Tensor, ...]:
# PyTorch 2.9 introduces a breaking change that
# sets the default value to True
kwargs.setdefault("dynamo", False)
- self.to_onnx(save_path, inputs_for_onnx, **kwargs)
- self.forward = old_forward # type: ignore
+ self.to_onnx(save_path, {"inputs": inputs}, **kwargs)
logger.info(f"Model exported to {save_path}")
self.set_export_mode(mode=False)
self.train()
- self.to(device_before) # reset device after export
+ self.to(device) # reset device after export
return Path(save_path)
@override
def training_step(
- self, train_batch: tuple[dict[str, Tensor], Labels]
+ self, train_batch: tuple[dict[str, Tensor] | Tensor, Labels]
) -> Tensor:
- outputs = self.forward(*train_batch)
+ outputs = self.full_forward(*train_batch)
if not outputs.losses:
raise ValueError("Losses are empty, check if you defined any loss")
@@ -429,23 +416,26 @@ def training_step(
@override
def validation_step(
- self, val_batch: tuple[dict[str, Tensor], Labels]
+ self, val_batch: tuple[dict[str, Tensor] | Tensor, Labels]
) -> dict[str, Tensor]:
return self._evaluation_step("val", *val_batch)
@override
def test_step(
- self, test_batch: tuple[dict[str, Tensor], Labels]
+ self, test_batch: tuple[dict[str, Tensor] | Tensor, Labels]
) -> dict[str, Tensor]:
return self._evaluation_step("test", *test_batch)
@override
def predict_step(
- self, batch: tuple[dict[str, Tensor], Labels]
+ self, batch: tuple[dict[str, Tensor] | Tensor, Labels]
) -> LuxonisOutput:
inputs, labels = batch
- images = get_denormalized_images(self.cfg, inputs[self.image_source])
- return self.forward(
+ images = get_denormalized_images(
+ self.cfg,
+ inputs[self.image_source] if isinstance(inputs, dict) else inputs,
+ )
+ return self.full_forward(
inputs,
labels,
images=images,
@@ -668,6 +658,14 @@ def load_checkpoint(self, ckpt: PathType | dict[str, Any] | None) -> None:
sub_state_dict, strict=False
)
+ def detach(self) -> None:
+ """Detaches the model from the trainer.
+
+ This is useful when the model needs to be used outside of the
+ training loop, for example for inference or exporting.
+ """
+ self.trainer = None
+
def _check_valid_epoch_counts(self, ckpt_config: dict) -> None:
previous_trainer_cfg = ckpt_config.get("trainer", {})
previous_epochs = previous_trainer_cfg.get("epochs", None)
@@ -685,10 +683,12 @@ def _check_valid_epoch_counts(self, ckpt_config: dict) -> None:
def _evaluation_step(
self,
mode: Literal["test", "val"],
- inputs: dict[str, Tensor],
+ inputs: dict[str, Tensor] | Tensor,
labels: Labels,
) -> dict[str, Tensor]:
max_log_images = self.cfg.trainer.n_log_images
+ if isinstance(inputs, Tensor):
+ inputs = {self.image_source: inputs}
input_image = inputs[self.image_source]
# Smart logging is decided based on the classification task keys that are merged for all tasks
@@ -699,7 +699,7 @@ def _evaluation_step(
if self._n_logged_images < max_log_images:
images = get_denormalized_images(self.cfg, input_image)
- outputs = self.forward(
+ outputs = self.full_forward(
inputs,
labels,
images=images,
@@ -955,25 +955,22 @@ def get_mlflow_logging_keys(self) -> dict[str, list[str]]:
)
for callback in self.cfg.trainer.callbacks:
+ model_name = self.cfg.exporter.name or self.cfg.model.name
if callback.name == "UploadCheckpoint":
artifact_keys.update(
{"best_val_metric.ckpt", "min_val_loss.ckpt"}
)
elif callback.name == "ExportOnTrainEnd":
- artifact_keys.add(
- f"{self.cfg.exporter.name or self.cfg.model.name}.onnx"
- )
+ artifact_keys.add(f"{model_name}.onnx")
elif callback.name == "ArchiveOnTrainEnd":
- artifact_keys.add(
- f"{self.cfg.exporter.name or self.cfg.model.name}.onnx.tar.xz"
- )
+ artifact_keys.add(f"{model_name}.onnx.tar.xz")
elif callback.name == "ConvertOnTrainEnd":
- artifact_keys.add(
- f"{self.cfg.exporter.name or self.cfg.model.name}.onnx"
- )
- artifact_keys.add(
- f"{self.cfg.exporter.name or self.cfg.model.name}.onnx.tar.xz"
- )
+ artifact_keys.add(f"{model_name}.onnx")
+ artifact_keys.add(f"{model_name}.onnx.tar.xz")
+ elif callback.name == "AIMETCallback":
+ artifact_keys.add(f"{model_name}.onnx")
+ artifact_keys.add(f"{model_name}.onnx.data")
+ artifact_keys.add(f"{model_name}.encodings")
elif callback.name == "TrainingProgressCallback":
metric_keys.update(
{
@@ -997,6 +994,20 @@ def get_mlflow_logging_keys(self) -> dict[str, list[str]]:
"artifacts": sorted(artifact_keys),
}
+ @override
+ def __getstate__(self):
+ state = super().__getstate__()
+ state["_core"] = None
+ _INTERNAL["trainer"] = self._trainer
+ _INTERNAL["core"] = self._core
+ return state
+
+ @override
+ def __setstate__(self, state: Any):
+ super().__setstate__(state)
+ self._trainer = _INTERNAL.get("trainer") # type: ignore
+ self._core = _INTERNAL.get("core")
+
def _get_node_order_mapping(
self, node_name: str, old_order: list[str], new_order: list[str]
) -> dict[str, str]:
@@ -1024,6 +1035,57 @@ def _strip_state_prefix(key: str) -> str:
idx = 3 if "module." in key else 2
return ".".join(key.split(".")[idx:])
+ def _get_output_onnx_names(self, inputs: dict[str, Tensor]) -> list[str]:
+ outputs = self.full_forward(inputs).outputs
+ output_order = sorted(
+ [
+ (node_name, output_name, i)
+ for node_name, outs in outputs.items()
+ for output_name, out in outs.items()
+ for i in range(len(out))
+ ]
+ )
+
+ output_counts = defaultdict(int)
+ for node_name, outs in outputs.items():
+ output_counts[node_name] = sum(len(out) for out in outs.values())
+
+ export_output_names_dict = {}
+ for node_name, node in self.nodes.items():
+ if node.module.export_output_names is not None:
+ if (
+ len(node.module.export_output_names)
+ != output_counts[node_name]
+ ):
+ logger.warning(
+ f"Number of provided output names for node {node_name} "
+ f"({len(node.module.export_output_names)}) does not match "
+ f"number of outputs ({output_counts[node_name]}). "
+ f"Using default names."
+ )
+ else:
+ export_output_names_dict[node_name] = (
+ node.module.export_output_names
+ )
+
+ output_names = []
+ # For cases where export_output_names should be used but
+ # output node's output is split into multiple subnodes
+ running_i = {}
+ for node_name, output_name, i in output_order:
+ if node_name in export_output_names_dict:
+ running_i[node_name] = (
+ running_i.get(node_name, -1) + 1
+ ) # if not present default to 0 otherwise add 1
+ output_names.append(
+ export_output_names_dict[node_name][running_i[node_name]]
+ )
+ else:
+ output_names.append(
+ f"{self.nodes[node_name].task_name}/{node_name}/{output_name}/{i}"
+ )
+ return output_names
+
def _add_custom_data_to_checkpoint(
self, checkpoint: dict[str, Any]
) -> None:
diff --git a/luxonis_train/lightning/utils.py b/luxonis_train/lightning/utils.py
index 6ff099023..34dbaccef 100644
--- a/luxonis_train/lightning/utils.py
+++ b/luxonis_train/lightning/utils.py
@@ -16,6 +16,7 @@
from torch import Size, Tensor, nn
from torch.optim.lr_scheduler import LRScheduler, SequentialLR
from torch.optim.optimizer import Optimizer
+from typing_extensions import override
import luxonis_train as lxt
from luxonis_train.attached_modules import BaseLoss, BaseMetric, BaseVisualizer
@@ -23,6 +24,7 @@
BaseAttachedModule,
)
from luxonis_train.callbacks import LuxonisModelSummary, TrainingManager
+from luxonis_train.callbacks.aimet_callback import AIMETCallback
from luxonis_train.config import AttachedModuleConfig, Config
from luxonis_train.config.config import NodeConfig
from luxonis_train.nodes import BaseNode
@@ -51,7 +53,7 @@ class MainMetric(NamedTuple):
class LossAccumulator(defaultdict[str, float]):
- def __init__(self):
+ def __init__(self, *args, **kwargs):
super().__init__(float)
self.counts = defaultdict(int)
@@ -82,9 +84,9 @@ def __init__(
super().__init__()
self.name = name
self.module = module
- self.losses = _to_module_dict(losses)
- self.metrics = _to_module_dict(metrics)
- self.visualizers = _to_module_dict(visualizers)
+ self.losses = losses
+ self.metrics = metrics
+ self.visualizers = visualizers
self.unfreeze_after = unfreeze_after
self.lr_after_unfreeze = lr_after_unfreeze
self.inputs = inputs or []
@@ -93,6 +95,17 @@ def __init__(
def task_name(self) -> str:
return self.module.task_name
+ @override
+ def train(self, mode: bool = True) -> "NodeWrapper":
+ self.module.train(mode)
+ for loss in self.losses.values():
+ loss.train(mode)
+ for metric in self.metrics.values():
+ metric.train(mode)
+ for visualizer in self.visualizers.values():
+ visualizer.train(mode)
+ return self
+
class Nodes(dict[str, NodeWrapper] if TYPE_CHECKING else nn.ModuleDict):
def __init__(
@@ -493,6 +506,9 @@ def build_callbacks(
"in the callbacks list. The `accumulate_grad_batches` "
"parameter in the config will be ignored."
)
+ if cfg.exporter.aimet.active:
+ callbacks.append(AIMETCallback())
+
if main_metric is not None:
node_name, metric_name = main_metric
formatted_node = nodes.formatted_name(node_name)
@@ -566,10 +582,6 @@ def _init_attached_module(
A = TypeVar("A", BaseLoss, BaseMetric, BaseVisualizer)
-def _to_module_dict(modules: dict[str, A]) -> dict[str, A]:
- return nn.ModuleDict(modules) # type: ignore
-
-
def log_balanced_class_images(
tracker: LuxonisTrackerPL,
nodes: Nodes,
diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py
index 16992faa7..f2dd5b1b7 100644
--- a/luxonis_train/loaders/base_loader.py
+++ b/luxonis_train/loaders/base_loader.py
@@ -16,7 +16,7 @@
from luxonis_train.typing import Labels
from luxonis_train.utils.general import get_attribute_check_none
-LuxonisLoaderTorchOutput = tuple[dict[str, Tensor], Labels]
+LuxonisLoaderTorchOutput = tuple[dict[str, Tensor] | Tensor, Labels]
class BaseLoaderTorch(
@@ -224,19 +224,8 @@ def augment_test_image(self, img: dict[str, Tensor]) -> Tensor:
"`augment_test_image` method to expose this functionality."
)
- def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput:
- img, labels = self.get(idx)
- if isinstance(img, Tensor):
- img = {self.image_source: img}
- return img, labels
-
- @abstractmethod
- def __len__(self) -> int:
- """Returns length of the dataset."""
- ...
-
@abstractmethod
- def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]:
+ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput:
"""Loads sample from dataset.
@type idx: int
@@ -246,6 +235,11 @@ def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]:
"""
...
+ @abstractmethod
+ def __len__(self) -> int:
+ """Returns length of the dataset."""
+ ...
+
@abstractmethod
def get_classes(self) -> dict[str, dict[str, int]]:
"""Gets classes according to computer vision task.
@@ -334,7 +328,7 @@ def img_numpy_to_torch(img: np.ndarray) -> Tensor:
def collate_fn(
self,
batch: list[LuxonisLoaderTorchOutput],
- ) -> tuple[dict[str, Tensor], Labels]:
+ ) -> tuple[dict[str, Tensor] | Tensor, Labels]:
"""Default collate function used for training.
@type batch: list[LuxonisLoaderTorchOutput]
@@ -345,13 +339,20 @@ def collate_fn(
@return: Tuple of inputs and annotations in the format expected
by the model.
"""
- inputs: tuple[dict[str, Tensor], ...]
+ inputs: tuple[dict[str, Tensor], ...] | tuple[Tensor, ...]
labels: tuple[Labels, ...]
inputs, labels = zip(*batch, strict=True)
- out_inputs = {
- k: torch.stack([i[k] for i in inputs], 0) for k in inputs[0]
- }
+ if isinstance(inputs[0], dict):
+ out_inputs = {
+ k: torch.stack(
+ [i[k] for i in inputs], # type: ignore
+ 0,
+ )
+ for k in inputs[0]
+ }
+ else:
+ out_inputs = torch.stack(inputs, 0) # type: ignore
out_labels: Labels = {}
diff --git a/luxonis_train/loaders/dummy_loader.py b/luxonis_train/loaders/dummy_loader.py
index 2b1f3fed4..7f2cc06da 100644
--- a/luxonis_train/loaders/dummy_loader.py
+++ b/luxonis_train/loaders/dummy_loader.py
@@ -86,7 +86,9 @@ def __len__(self) -> int:
return self.batch_size * 10
@override
- def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]:
+ def __getitem__(
+ self, idx: int
+ ) -> tuple[Tensor | dict[str, Tensor], Labels]:
img = torch.zeros(self.n_channels, self.height, self.width)
label_shapes = self.get_label_shapes(self.labels)
labels = {
diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py
index fe1713e2b..33349e2a0 100644
--- a/luxonis_train/loaders/luxonis_loader_torch.py
+++ b/luxonis_train/loaders/luxonis_loader_torch.py
@@ -151,11 +151,15 @@ def __len__(self) -> int:
@property
@override
def input_shapes(self) -> dict[str, Size]:
- img = self[0][0][self.image_source]
+ img = self[0][0]
+ if isinstance(img, dict):
+ img = img[self.image_source]
return {self.image_source: img.shape}
@override
- def get(self, idx: int) -> tuple[dict[str, Tensor], Labels]:
+ def __getitem__(
+ self, idx: int
+ ) -> tuple[dict[str, Tensor] | Tensor, Labels]:
img, labels = self.loader[idx]
if isinstance(img, np.ndarray):
img = {self.image_source: img}
@@ -164,6 +168,8 @@ def get(self, idx: int) -> tuple[dict[str, Tensor], Labels]:
labels = self._remap_keypoints(labels)
img = {k: self.img_numpy_to_torch(v) for k, v in img.items()}
+ if len(img) == 1:
+ img = next(iter(img.values()))
return img, self.dict_numpy_to_torch(labels)
diff --git a/luxonis_train/loaders/luxonis_perlin_loader_torch.py b/luxonis_train/loaders/luxonis_perlin_loader_torch.py
index 27c43b88e..867f6def4 100644
--- a/luxonis_train/loaders/luxonis_perlin_loader_torch.py
+++ b/luxonis_train/loaders/luxonis_perlin_loader_torch.py
@@ -78,7 +78,7 @@ def __init__(
self.augmentations = self.loader.augmentations
@override
- def get(self, idx: int) -> tuple[Tensor, Labels]:
+ def __getitem__(self, idx: int) -> tuple[Tensor, Labels]:
with _freeze_seed():
img, labels = self.loader[idx]
if isinstance(img, dict):
diff --git a/luxonis_train/nodes/backbones/dinov3/dinov3.py b/luxonis_train/nodes/backbones/dinov3/dinov3.py
index 9e88e3cce..20fcf813e 100644
--- a/luxonis_train/nodes/backbones/dinov3/dinov3.py
+++ b/luxonis_train/nodes/backbones/dinov3/dinov3.py
@@ -206,6 +206,7 @@ def _get_backbone(
model=model_name,
weights=weights,
source="github",
+ trust_repo=True, # type: ignore
**kwargs,
)
diff --git a/luxonis_train/nodes/backbones/efficientnet.py b/luxonis_train/nodes/backbones/efficientnet.py
index d5b747280..c92c746b6 100644
--- a/luxonis_train/nodes/backbones/efficientnet.py
+++ b/luxonis_train/nodes/backbones/efficientnet.py
@@ -43,6 +43,7 @@ class GenEfficientNet(nn.Module):
"rwightman/gen-efficientnet-pytorch",
"efficientnet_lite0",
pretrained=weights == "download",
+ trust_repo=True, # type: ignore
),
)
self.out_indices = out_indices or [0, 1, 2, 4, 6]
diff --git a/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py b/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py
index 57f3dc63d..20913ab18 100644
--- a/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py
+++ b/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py
@@ -1,3 +1,5 @@
+from contextlib import suppress
+
import torch
from torch import Tensor, nn
from typeguard import typechecked
@@ -31,6 +33,34 @@ def forward(self, x: Tensor) -> Tensor:
return self.scale * x + self.bias
+with suppress(ImportError):
+ from aimet_torch.v2.nn import QuantizationMixin
+
+ @QuantizationMixin.implements(AffineBlock)
+ class QuantizedAffineBlock(QuantizationMixin, AffineBlock):
+ def __quant_init__(self):
+ super().__quant_init__()
+
+ # Declare the number of input/output quantizers
+ self.input_quantizers = nn.ModuleList([None]) # type: ignore
+ self.output_quantizers = nn.ModuleList([None]) # type: ignore
+
+ def forward(self, x: Tensor) -> Tensor:
+ # Quantize input tensors
+ if self.input_quantizers[0]:
+ x = self.input_quantizers[0](x)
+
+ # Run forward with quantized inputs and parameters
+ with self._patch_quantized_parameters():
+ ret = super().forward(x)
+
+ # Quantize output tensors
+ if self.output_quantizers[0]:
+ ret = self.output_quantizers[0](ret)
+
+ return ret
+
+
class LCNetV3Block(nn.Module):
@typechecked
def __init__(
diff --git a/luxonis_train/nodes/blocks/__init__.py b/luxonis_train/nodes/blocks/__init__.py
index 9e9eb783a..ffc323dd5 100644
--- a/luxonis_train/nodes/blocks/__init__.py
+++ b/luxonis_train/nodes/blocks/__init__.py
@@ -1,3 +1,5 @@
+from contextlib import suppress
+
from .blocks import (
DFL,
AttentionRefinmentBlock,
@@ -27,6 +29,12 @@
UpBlock,
)
+with suppress(ImportError):
+ from aimet_torch.v2.nn import QuantizationMixin
+
+ QuantizationMixin.ignore(DropPath)
+ QuantizationMixin.ignore(UpscaleOnline)
+
__all__ = [
"DFL",
"AttentionRefinmentBlock",
diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py
index 79e27a5da..172104d8e 100644
--- a/luxonis_train/nodes/blocks/blocks.py
+++ b/luxonis_train/nodes/blocks/blocks.py
@@ -467,7 +467,7 @@ def name(self) -> str:
@override
def reparametrize(self) -> None:
if self.fused_branch is not None:
- raise RuntimeError(f"{self.name} is already reparametrized")
+ return
kernel, bias = self._fuse_parameters()
fused_branch = nn.Conv2d(
@@ -489,10 +489,7 @@ def reparametrize(self) -> None:
@override
def restore(self) -> None:
if self.fused_branch is None:
- raise RuntimeError(
- f"Cannot restore '{self.name}' "
- "that has not yet been reparametrized."
- )
+ return
# Not sure if this is necessary
for param in self.fused_branch.parameters():
diff --git a/luxonis_train/registry.py b/luxonis_train/registry.py
index 258679e5b..f1c60bdf8 100644
--- a/luxonis_train/registry.py
+++ b/luxonis_train/registry.py
@@ -1,7 +1,7 @@
"""This module implements a metaclass for automatic registration of
classes."""
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Any, TypeVar
from luxonis_ml.utils.registry import Registry
@@ -34,6 +34,8 @@
VISUALIZERS: Registry[type["lxt.BaseVisualizer"]] = Registry("visualizers")
+_INTERNAL: dict[str, Any] = {}
+
T = TypeVar("T")
diff --git a/luxonis_train/utils/__init__.py b/luxonis_train/utils/__init__.py
index 2a10ebf4f..ac942c71a 100644
--- a/luxonis_train/utils/__init__.py
+++ b/luxonis_train/utils/__init__.py
@@ -63,27 +63,18 @@
"filter_checkpoint_state_dict",
"get_attribute_check_none",
"get_batch_instances",
- "get_batch_instances",
"get_center_keypoints",
"get_sigmas",
"get_with_default",
"infer_upscale_factor",
"insert_class",
"instances_from_batch",
- "instances_from_batch",
- "instances_from_batch",
- "instances_from_batch",
"keypoints_to_bboxes",
"make_divisible",
- "make_divisible",
"non_max_suppression",
- "non_max_suppression",
- "safe_download",
"safe_download",
"seg_output_to_bool",
"setup_logging",
- "setup_logging",
- "to_shape_packet",
"to_shape_packet",
"transform_boxes",
"transform_keypoints",
diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py
index 04623f72b..7691e65be 100644
--- a/luxonis_train/utils/dataset_metadata.py
+++ b/luxonis_train/utils/dataset_metadata.py
@@ -103,7 +103,8 @@ def n_classes(self, task_name: str | None = None) -> int:
if task_name is not None:
if task_name not in self._classes:
raise ValueError(
- f"Task '{task_name}' is not present in the dataset."
+ f"Task '{task_name}' is not present in the dataset. "
+ f"Available tasks: {self.task_names}"
)
return len(self._classes[task_name])
n_classes = len(next(iter(self._classes.values())))
@@ -121,19 +122,14 @@ def n_keypoints(self, task_name: str | None = None) -> int:
@type task_name: str | None
@param task_name: Task to get the number of keypoints for.
@rtype: int
- @return: Number of keypoints for the specified task type.
- @raises ValueError: If the C{task} is not present in the
- dataset.
+ @return: Number of keypoints for the specified task type or 0 if
+ the task does not involve keypoints.
@raises RuntimeError: If the C{task} was not provided and the
dataset contains different number of keypoints for different
task types.
"""
if task_name is not None:
- if task_name not in self._n_keypoints:
- raise ValueError(
- f"Task '{task_name}' is not present in the dataset."
- )
- return self._n_keypoints[task_name]
+ return self._n_keypoints.get(task_name, 0)
n_keypoints = next(iter(self._n_keypoints.values()))
for n in self._n_keypoints.values():
if n != n_keypoints:
@@ -160,7 +156,8 @@ def classes(self, task_name: str | None = None) -> bidict[str, int]:
if task_name is not None:
if task_name not in self._classes:
raise ValueError(
- f"Task '{task_name}' is not present in the dataset."
+ f"Task '{task_name}' is not present in the dataset. "
+ f"Available tasks: {self.task_names}"
)
return bidict(self._classes[task_name])
classes = next(iter(self._classes.values()))
diff --git a/media/anomaly_detection_diagram.drawio b/media/anomaly_detection_diagram.drawio
index ee134024a..7411296c6 100644
--- a/media/anomaly_detection_diagram.drawio
+++ b/media/anomaly_detection_diagram.drawio
@@ -113,7 +113,7 @@
-
+
diff --git a/pyproject.toml b/pyproject.toml
index 77b738169..968810f9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,8 +35,10 @@ where = ["."]
[tool.setuptools.dynamic]
dependencies = { file = ["requirements.txt"] }
-optional-dependencies = { dev = { file = ["requirements-dev.txt"] } }
version = { attr = "luxonis_train.__version__" }
+[tool.setuptools.dynamic.optional-dependencies]
+aimet = { file = ["requirements-aimet.txt"] }
+dev = { file = ["requirements-dev.txt"] }
[tool.ruff]
target-version = "py310"
diff --git a/requirements-aimet.txt b/requirements-aimet.txt
new file mode 100644
index 000000000..78ab17fbe
--- /dev/null
+++ b/requirements-aimet.txt
@@ -0,0 +1,3 @@
+aimet-torch~=2.31
+torch==2.11
+torchvision~=0.26
diff --git a/requirements.txt b/requirements.txt
index 490eb03d8..44546ef2b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,7 +19,6 @@ semver~=3.0
tabulate~=0.9
tensorboard~=2.20
termcolor~=3.2
-torch<2.11 # temporary pin: GitHub runner image does not yet support the newer NVIDIA driver requirement used by torch 2.11 wheels
torchmetrics~=1.8
torchvision~=0.24
hubai-sdk>=0.2.5
diff --git a/tests/conftest.py b/tests/conftest.py
index 5fab5f96a..a68d8f08e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -405,6 +405,18 @@ def opts(save_dir: Path, image_size: tuple[int, int]) -> Params:
],
"tracker.save_directory": str(save_dir),
"trainer.preprocessing.train_image_size": image_size,
+ "exporter.aimet": {
+ "active": False,
+ "epochs": 1,
+ "fold_batch_norms": True,
+ "batch_norm_reestimation": True,
+ "cross_layer_equalization": True,
+ "sequential_mse": True,
+ },
+ "exporter.aimet.adaround": {
+ "active": True,
+ "default_num_iterations": 1,
+ },
}
diff --git a/tests/integration/backbone_model_utils.py b/tests/integration/backbone_model_utils.py
index b3a3ddf6c..725d46285 100644
--- a/tests/integration/backbone_model_utils.py
+++ b/tests/integration/backbone_model_utils.py
@@ -76,5 +76,7 @@ def prepare_predefined_model_config(
}
elif "ocr_recognition" in config_file:
opts["trainer.preprocessing.train_image_size"] = [48, 320]
+ elif "instance_segmentation" in config_name:
+ opts |= {"exporter.aimet.batch_norm_reestimation": False}
return config_file, opts, dataset
diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py
index 0e782b0a5..070ca4475 100644
--- a/tests/integration/test_callbacks.py
+++ b/tests/integration/test_callbacks.py
@@ -56,6 +56,8 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path):
"exporter.scale_values": [0.5, 0.5, 0.5],
"exporter.mean_values": [0.5, 0.5, 0.5],
"exporter.blobconverter.active": True,
+ # AIMET fails when determinism is enabled
+ "exporter.aimet.active": False,
"loader.params.dataset_name": coco_dataset.identifier,
}
model = LuxonisModel(config_file, opts, allow_empty_dataset=True)
diff --git a/tests/integration/test_combinations.py b/tests/integration/test_combinations.py
index 2c3048b69..d9a5b4825 100644
--- a/tests/integration/test_combinations.py
+++ b/tests/integration/test_combinations.py
@@ -143,7 +143,9 @@ def test_combinations(
subtests: SubTests,
):
config = get_config(backbone, dinov3_weights)
- opts |= {"loader.params.dataset_name": parking_lot_dataset.identifier}
+ opts |= {
+ "loader.params.dataset_name": parking_lot_dataset.identifier,
+ }
model = LuxonisModel(config, opts)
with subtests.test("train"):
diff --git a/tests/integration/test_custom_model.py b/tests/integration/test_custom_model.py
index c4915bc5a..8c2cc8f18 100644
--- a/tests/integration/test_custom_model.py
+++ b/tests/integration/test_custom_model.py
@@ -30,7 +30,7 @@ def input_shapes(self):
"pointcloud": torch.Size([self.n_points, 3]),
}
- def get(self, _: int) -> LuxonisLoaderTorchOutput:
+ def __getitem__(self, _: int) -> LuxonisLoaderTorchOutput:
left = torch.rand(3, self.height, self.width, dtype=torch.float32)
right = torch.rand(3, self.height, self.width, dtype=torch.float32)
disparity = torch.rand(1, self.height, self.width, dtype=torch.float32)
diff --git a/tests/integration/test_predefined_models.py b/tests/integration/test_predefined_models.py
index 053c62fc9..145b8b8c4 100644
--- a/tests/integration/test_predefined_models.py
+++ b/tests/integration/test_predefined_models.py
@@ -65,6 +65,12 @@ def test_predefined_models(
model.run_save_dir / "archive" / f"{config_name}.onnx.tar.xz"
).exists()
+ with subtests.test("quantize"):
+ save_dir = model.quantize()
+ assert (save_dir / f"{config_name}.encodings").exists()
+ assert (save_dir / f"{config_name}.onnx").exists()
+ assert (save_dir / f"{config_name}.onnx.data").exists()
+
if config_name != "embeddings_model":
with subtests.test("infer"):
loader = LuxonisLoader(dataset)
diff --git a/tests/unittests/test_loaders/test_base_loader.py b/tests/unittests/test_loaders/test_base_loader.py
index 4be1e5c9d..295f6a4a7 100644
--- a/tests/unittests/test_loaders/test_base_loader.py
+++ b/tests/unittests/test_loaders/test_base_loader.py
@@ -9,7 +9,7 @@
class DummyLoader(BaseLoaderTorch):
def __len__(self) -> int: ...
- def get(self, idx: int) -> LuxonisLoaderTorchOutput: ...
+ def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: ...
def get_classes(self) -> dict[str, dict[str, int]]: ...
@@ -68,6 +68,7 @@ def build_batch_element() -> LuxonisLoaderTorchOutput:
inputs, annotations = loader.collate_fn(batch)
with subtests.test("inputs"):
+ assert isinstance(inputs, dict)
assert inputs["features"].shape == (batch_size, 3, 224, 224)
assert inputs["features"].dtype == torch.float32
diff --git a/tests/unittests/test_utils/test_dataset_metadata.py b/tests/unittests/test_utils/test_dataset_metadata.py
index ee2f7478a..ceb75c818 100644
--- a/tests/unittests/test_utils/test_dataset_metadata.py
+++ b/tests/unittests/test_utils/test_dataset_metadata.py
@@ -29,8 +29,7 @@ def test_n_keypoints(metadata: DatasetMetadata):
assert metadata.n_keypoints("color-segmentation") == 0
assert metadata.n_keypoints("detection") == 0
assert metadata.n_keypoints() == 0
- with pytest.raises(ValueError, match="Task 'segmentation'"):
- metadata.n_keypoints("segmentation")
+ assert metadata.n_keypoints("segmentation") == 0
metadata._n_keypoints["segmentation"] = 1
with pytest.raises(RuntimeError, match="different number of keypoints"):
metadata.n_keypoints()