From b26e70f9c6b1ae9361e3ba1622c3805cb03d1ab1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sun, 22 Mar 2026 19:16:44 -0700 Subject: [PATCH 01/74] aimet --- luxonis_train/callbacks/aimet_callback.py | 34 +++++++++++++++++++++++ luxonis_train/core/core.py | 8 ++++++ 2 files changed, 42 insertions(+) create mode 100644 luxonis_train/callbacks/aimet_callback.py diff --git a/luxonis_train/callbacks/aimet_callback.py b/luxonis_train/callbacks/aimet_callback.py new file mode 100644 index 00000000..eedcb788 --- /dev/null +++ b/luxonis_train/callbacks/aimet_callback.py @@ -0,0 +1,34 @@ +from typing import Literal + +import lightning.pytorch as pl +from loguru import logger + +import luxonis_train as lxt +from luxonis_train.callbacks.needs_checkpoint import NeedsCheckpoint + + +class AIMETCallback(NeedsCheckpoint): + def __init__(self, mode: Literal["PTQ", "QAT"]): + super().__init__() + + def on_train_end( + self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule" + ) -> None: + onnx_path = pl_module.core._exported_models.get("onnx") + if onnx_path is None: # pragma: no cover + checkpoint = self.get_checkpoint(pl_module) + if checkpoint is None: + logger.warning("Skipping model archiving.") + return + logger.info("Exported model not found. Exporting to ONNX...") + pl_module.core.export(weights=checkpoint) + onnx_path = pl_module.core._exported_models.get("onnx") + + if onnx_path is None: # pragma: no cover + logger.error( + "Model executable not found and couldn't be created. " + "Skipping AIMET." + ) + return + + pl_module.core.quantize(onnx_path, mode) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2e11468b..1f84bf6c 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -6,10 +6,12 @@ import lightning.pytorch as pl import lightning_utilities.core.rank_zero as rank_zero_module +import onnx import rich.traceback import torch import torch.utils.data as torch_data import yaml +from aimet_onnx.quantsim import QuantizationSimModel from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities import rank_zero_only from loguru import logger @@ -1139,6 +1141,12 @@ def convert( return archive_path, conversion_artifacts + def quantize( + self, onnx_path: PathType, mode: Literal["PTQ", "QAT"] + ) -> None: + model = onnx.load_model(onnx_path) + sim = QuantizationSimModel(model=model) + @property def environ(self) -> Environ: return self.cfg.ENVIRON From ba56cb92be2c496e785ef87c235a0fa8c84ec973 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sun, 29 Mar 2026 17:35:05 +0200 Subject: [PATCH 02/74] ptq --- luxonis_train/__main__.py | 19 +- luxonis_train/callbacks/__init__.py | 2 + luxonis_train/callbacks/aimet_callback.py | 25 +- .../callbacks/luxonis_progress_bar.py | 71 ++++-- luxonis_train/core/core.py | 99 ++++++-- luxonis_train/core/utils/infer_utils.py | 4 +- luxonis_train/lightning/luxonis_lightning.py | 221 ++++++++++-------- 7 files changed, 283 insertions(+), 158 deletions(-) diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index e3d32b1c..6e6d356d 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -75,10 +75,13 @@ def create_model( ) return LuxonisModel( - cfg, debug_mode=debug_mode, dataset_metadata=dataset_metadata + cfg, + debug_mode=debug_mode, + dataset_metadata=dataset_metadata, + weights=weights, ) - return LuxonisModel(config, opts, debug_mode=debug_mode) + return LuxonisModel(config, opts, weights=weights, debug_mode=debug_mode) @app.command(group=training_group, sort_key=1) @@ -431,6 +434,18 @@ def convert( ) +@app.command(group=export_group, sort_key=1) +def quantize( + opts: list[str] | None = None, + /, + *, + config: str | None = None, + weights: str | None = None, +): + model = create_model(config, opts, weights=weights, debug_mode=True) + model.quantize() + + @upgrade_app.command() def config( config: Annotated[ diff --git a/luxonis_train/callbacks/__init__.py b/luxonis_train/callbacks/__init__.py index 3b29218f..3fbde3ff 100644 --- a/luxonis_train/callbacks/__init__.py +++ b/luxonis_train/callbacks/__init__.py @@ -11,6 +11,7 @@ from luxonis_train.registry import CALLBACKS +from .aimet_callback import AIMETCallback from .archive_on_train_end import ArchiveOnTrainEnd from .convert_on_train_end import ConvertOnTrainEnd from .ema import EMACallback @@ -50,6 +51,7 @@ __all__ = [ + "AIMETCallback", "ArchiveOnTrainEnd", "BaseLuxonisProgressBar", "ConvertOnTrainEnd", diff --git a/luxonis_train/callbacks/aimet_callback.py b/luxonis_train/callbacks/aimet_callback.py index eedcb788..09033dd4 100644 --- a/luxonis_train/callbacks/aimet_callback.py +++ b/luxonis_train/callbacks/aimet_callback.py @@ -1,34 +1,19 @@ from typing import Literal import lightning.pytorch as pl -from loguru import logger import luxonis_train as lxt from luxonis_train.callbacks.needs_checkpoint import NeedsCheckpoint +from luxonis_train.registry import CALLBACKS +@CALLBACKS.register() class AIMETCallback(NeedsCheckpoint): - def __init__(self, mode: Literal["PTQ", "QAT"]): + def __init__(self, mode: Literal["PTQ", "QAT"] = "PTQ"): super().__init__() + self.mode = mode def on_train_end( self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule" ) -> None: - onnx_path = pl_module.core._exported_models.get("onnx") - if onnx_path is None: # pragma: no cover - checkpoint = self.get_checkpoint(pl_module) - if checkpoint is None: - logger.warning("Skipping model archiving.") - return - logger.info("Exported model not found. Exporting to ONNX...") - pl_module.core.export(weights=checkpoint) - onnx_path = pl_module.core._exported_models.get("onnx") - - if onnx_path is None: # pragma: no cover - logger.error( - "Model executable not found and couldn't be created. " - "Skipping AIMET." - ) - return - - pl_module.core.quantize(onnx_path, mode) + pl_module.core.quantize() diff --git a/luxonis_train/callbacks/luxonis_progress_bar.py b/luxonis_train/callbacks/luxonis_progress_bar.py index 8dd8f2c1..3e5faa2a 100644 --- a/luxonis_train/callbacks/luxonis_progress_bar.py +++ b/luxonis_train/callbacks/luxonis_progress_bar.py @@ -1,6 +1,6 @@ import time from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from io import StringIO from typing import Any @@ -59,6 +59,24 @@ def print_results( """ ... + @abstractmethod + def print_table( + self, + title: str, + table: Iterable[tuple[str | int | float, ...]], + column_names: list[str], + ) -> None: + """Prints table to the console. + + @type title: str + @param title: Title of the table + @type table: Mapping[str, int | str | float] + @param table: Table to print + @type column_names: list[str] + @param column_names: Names of the columns in the table + """ + ... + def _log_progress(self, trainer: pl.Trainer) -> None: duration = ( time.time() - self._epoch_start_time @@ -129,7 +147,9 @@ def print_results( logger.info(f"Loss: {loss}") logger.info("Metrics:") for table_name, table in metrics.items(): - self._print_table(table_name, table) + self.print_table( + table_name, list(table.items()), ["Name", "Value"] + ) for matrix_name, matrix in matrices.get(table_name, {}).items(): self._print_matrix( self._format_matrix_title(matrix_name), matrix @@ -150,12 +170,12 @@ def _rule(self, title: str | None = None) -> None: else: logger.info("-----------------") - def _print_table( + @override + def print_table( self, title: str, - table: Mapping[str, int | str | float], - key_name: str = "Name", - value_name: str = "Value", + table: Iterable[tuple[str | int | float, ...]], + column_names: list[str], ) -> None: """Prints table to the console using tabulate. @@ -171,8 +191,8 @@ def _print_table( """ self._rule(title) formatted = tabulate( - table.items(), - headers=[key_name, value_name], + table, + headers=column_names, tablefmt="fancy_grid", numalign="right", ) @@ -250,7 +270,9 @@ def print_results( ) self.console.print("[bold magenta]Metrics:[/bold magenta]") for table_name, table in metrics.items(): - self._print_table(table_name, table) + self.print_table( + table_name, list(table.items()), ["Name", "Value"] + ) for matrix_name, matrix in matrices.get(table_name, {}).items(): self._print_matrix( self._format_matrix_title(matrix_name), matrix @@ -270,7 +292,12 @@ def print_results( self._log_console.print(f"Loss: {loss}") self._log_console.print("Metrics:") for table_name, table in metrics.items(): - self._print_table(table_name, table, console=self._log_console) + self.print_table( + table_name, + list(table.items()), + ["Name", "Value"], + console=self._log_console, + ) for matrix_name, matrix in matrices.get(table_name, {}).items(): self._print_matrix( self._format_matrix_title(matrix_name), @@ -293,12 +320,12 @@ def print_results( self._log_buffer.seek(0) self._log_buffer.truncate(0) - def _print_table( + @override + def print_table( self, title: str, - table: Mapping[str, int | str | float], - key_name: str = "Name", - value_name: str = "Value", + table: Iterable[tuple[str | int | float, ...]], + column_names: list[str], console: Console | None = None, ) -> None: """Prints table to the console using rich text. @@ -323,10 +350,18 @@ def _print_table( header_style="bold magenta", title_style="bold", ) - rich_table.add_column(key_name, style="magenta") - rich_table.add_column(value_name, style="white") - for name, value in table.items(): - rich_table.add_row(name, f"{value:.5f}") + for i, column_name in enumerate(column_names): + rich_table.add_column( + column_name, style="magenta" if i == 0 else "white" + ) + for name, *values in table: + rich_table.add_row( + str(name), + *[ + f"{value:.5f}" if isinstance(value, float) else str(value) + for value in values + ], + ) console.print(rich_table) def _print_matrix( diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 1f84bf6c..39dd1d32 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1,17 +1,16 @@ import threading from collections.abc import Mapping +from copy import deepcopy from pathlib import Path -from threading import ExceptHookArgs +from threading import ExceptHookArgs, Thread from typing import Literal, overload import lightning.pytorch as pl import lightning_utilities.core.rank_zero as rank_zero_module -import onnx import rich.traceback import torch import torch.utils.data as torch_data import yaml -from aimet_onnx.quantsim import QuantizationSimModel from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities import rank_zero_only from loguru import logger @@ -20,6 +19,7 @@ from luxonis_ml.nn_archive.config import CONFIG_VERSION from luxonis_ml.typing import Params, PathType from luxonis_ml.utils import Environ, LuxonisFileSystem +from torch import nn from typeguard import typechecked from luxonis_train.callbacks import ( @@ -80,6 +80,7 @@ def __init__( cfg: PathType | Params | Config | None, opts: Params | list[str] | tuple[str, ...] | None = None, *, + weights: PathType | None = None, debug_mode: bool = False, dataset_metadata: DatasetMetadata | None = None, ): @@ -286,6 +287,16 @@ def __init__( _core=self, ) + if weights is not None: + weights = LuxonisFileSystem.download( + str(weights), self.run_save_dir + ) + if self.cfg.model.weights is not None: + logger.warning( + "Weights provided in the command line, but config weights are set. " + "Ignoring weights provided in config." + ) + self.lightning_module.load_checkpoint(weights) self._exported_models: dict[str, Path] = {} def _train(self, resume: PathType | None, *args, **kwargs) -> None: @@ -385,7 +396,7 @@ def export( weights: PathType | None = None, ignore_missing_weights: bool = False, ckpt_only: bool = False, - ) -> None: + ) -> Path: """Runs export. @type save_path: PathType | None @@ -406,6 +417,9 @@ def export( This is useful for updating the metadata in the checkpoint file in case they changed (e.g. new configuration file, architectural changes affecting the exection order etc.) + + @rtype: Path + @return: Path to the exported ONNX model file or .ckpt file if ckpt_only is True. """ weights = weights or self.cfg.model.weights @@ -438,7 +452,7 @@ def export( logger.info( f"Checkpoint saved to {export_path.with_suffix('.ckpt')}" ) - return + return export_path.with_suffix(".ckpt") with replace_weights(self.lightning_module, weights): onnx_kwargs = self.cfg.exporter.onnx.model_dump( @@ -476,7 +490,7 @@ def export( "Generating modelconverter config for a model " "with multiple inputs is not implemented yet." ) - return + return onnx_save_path inputs = [] outputs = [] @@ -518,6 +532,8 @@ def export( if self.cfg.exporter.upload_url is not None: # pragma: no cover LuxonisFileSystem.upload(f.name, self.cfg.exporter.upload_url) + return onnx_save_path + @overload def test( self, @@ -532,7 +548,7 @@ def test( new_thread: Literal[True] = ..., view: Literal["train", "test", "val"] = "test", weights: PathType | None = ..., - ) -> None: ... + ) -> Thread: ... @typechecked def test( @@ -540,22 +556,22 @@ def test( new_thread: bool = False, view: Literal["train", "val", "test"] = "test", weights: PathType | None = None, - ) -> Mapping[str, float] | None: + ) -> Mapping[str, float] | Thread: """Runs testing. @type new_thread: bool @param new_thread: Runs testing in a new thread if set to True. @type view: Literal["train", "test", "val"] @param view: Which view to run the testing on. Defauls to "test". - @rtype: Mapping[str, float] | None - @return: If new_thread is False, returns a dictionary test - results. @type weights: PathType | None @param weights: Path to the checkpoint from which to load weights. If not specified, the value of `model.weights` from the configuration file will be used. The current weights of the model will be temporarily replaced with the weights from the specified checkpoint. + @rtype: Mapping[str, float] | Thread + @return: If new_thread is False, returns a dictionary test + results. """ weights = weights or self.cfg.model.weights loader = self.pytorch_loaders[view] @@ -567,7 +583,8 @@ def test( args=(self.lightning_module, loader), daemon=True, ) - return self.thread.start() + self.thread.start() + return self.thread return self.pl_trainer.test(self.lightning_module, loader)[0] def infer( @@ -1141,11 +1158,59 @@ def convert( return archive_path, conversion_artifacts - def quantize( - self, onnx_path: PathType, mode: Literal["PTQ", "QAT"] - ) -> None: - model = onnx.load_model(onnx_path) - sim = QuantizationSimModel(model=model) + def quantize(self) -> None: + from aimet_torch import QuantizationSimModel + + model = self.lightning_module + save_dir = self.run_save_dir / "aimet" + save_dir.mkdir(parents=True, exist_ok=True) + pre_quant_test = self.test(view="val") + + inputs = { + input_name: torch.randn([1, *shape]).to(model.device) + for shapes in model.nodes.loader_input_shapes.values() + for input_name, shape in shapes.items() + } + + if len(inputs) > 1: + raise NotImplementedError( + "Quantization is not yet supported for models " + "with multiple inputs." + ) + input_names = list(inputs.keys()) + output_names = model._get_output_onnx_names(deepcopy(inputs)) + inputs = next(iter(inputs.values())) + + sim = QuantizationSimModel( + model=model, + dummy_input=inputs, + in_place=True, + ) + + def pass_calibration_data(model: nn.Module) -> None: + for imgs, _ in self.pytorch_loaders["val"]: + model.forward(imgs) + + sim.compute_encodings(pass_calibration_data) + + post_quant_test = self.test(view="val") + model.set_export_mode(mode=True) + sim.onnx.export( + inputs, + (save_dir / self.cfg.model.name).with_suffix(".onnx"), + input_names=input_names, + output_names=output_names, + ) + model.set_export_mode(mode=False) + table = [] + for key, value in pre_quant_test.items(): + log_key = key.replace("test/metric/", "").replace("test/loss/", "") + table.append((log_key, value, post_quant_test[key])) + model.progress_bar.print_table( + "Quantization results", + table, + ["Name", "Pre-quantization", "Post-quantization"], + ) @property def environ(self) -> Environ: diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 3e57a210..05e7ee2a 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -58,8 +58,8 @@ def prepare_and_infer_image( npy_img = model.loaders["val"].augment_test_image(images) torch_img = torch.tensor(npy_img).unsqueeze(0).permute(0, 3, 1, 2).float() - return model.lightning_module.forward( - {"image": torch_img}, + return model.lightning_module.run_forward_step( + {model.lightning_module.image_source: torch_img}, images=get_denormalized_images(model.cfg, torch_img), compute_visualizations=True, ) diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index 61427d98..6f2a41bf 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -1,6 +1,7 @@ import re from collections import defaultdict from collections.abc import Callable, Mapping +from copy import deepcopy from pathlib import Path from typing import Any, Literal, cast @@ -187,8 +188,7 @@ def core(self) -> "luxonis_train.core.LuxonisModel": raise ValueError("Core reference is not set.") return self._core - @override - def forward( + def run_forward_step( self, inputs: dict[str, Tensor], labels: Labels | None = None, @@ -286,6 +286,45 @@ def forward( outputs=outputs_dict, losses=losses, visualizations=visualizations ) + @override + def forward( + self, inputs: dict[str, Tensor] | Tensor + ) -> tuple[Tensor, ...]: + """Forward pass of the model. + + @type inputs: L{Tensor} + @param inputs: Input tensors. + @rtype: dict[str, L{Packet}[L{Tensor}]] + @return: Output of the model. + """ + if isinstance(inputs, Tensor): + inputs = {self.image_source: inputs} + + outputs = self.run_forward_step( + inputs, + compute_loss=False, + compute_metrics=False, + compute_visualizations=False, + ).outputs + + output_order = sorted( + [ + (node_name, output_name, i) + for node_name, outs in outputs.items() + for output_name, out in outs.items() + for i in range(len(out)) + ] + ) + new_outputs = [] + for node_name, output_name, i in output_order: + node_output = outputs[node_name][output_name] + if isinstance(node_output, Tensor): + new_outputs.append(node_output) + else: + new_outputs.append(node_output[i]) + + return tuple(new_outputs) + def set_export_mode(self, *, mode: bool) -> None: for module in self.modules(): if isinstance(module, BaseNode): @@ -302,7 +341,7 @@ def export_onnx(self, save_path: PathType, **kwargs) -> Path: @rtype: Path @return: Path to the exported model. """ - device_before = self.device + device = self.device self.eval() self.to("cpu") # move to CPU to support deterministic .to_onnx() @@ -312,88 +351,13 @@ def export_onnx(self, save_path: PathType, **kwargs) -> Path: for shapes in self.nodes.loader_input_shapes.values() for input_name, shape in shapes.items() } - - inputs_deep_clone = { - k: torch.zeros(elem.shape).to(self.device) - for k, elem in inputs.items() - } - - inputs_for_onnx = {"inputs": inputs_deep_clone} + if "input_names" not in kwargs: + kwargs["input_names"] = list(inputs.keys()) self.set_export_mode(mode=True) - outputs = self.forward(inputs_deep_clone).outputs - output_order = sorted( - [ - (node_name, output_name, i) - for node_name, outs in outputs.items() - for output_name, out in outs.items() - for i in range(len(out)) - ] - ) - - output_counts = defaultdict(int) - for node_name, outs in outputs.items(): - output_counts[node_name] = sum(len(out) for out in outs.values()) - - export_output_names_dict = {} - for node_name, node in self.nodes.items(): - if node.module.export_output_names is not None: - if ( - len(node.module.export_output_names) - != output_counts[node_name] - ): - logger.warning( - f"Number of provided output names for node {node_name} " - f"({len(node.module.export_output_names)}) does not match " - f"number of outputs ({output_counts[node_name]}). " - f"Using default names." - ) - else: - export_output_names_dict[node_name] = ( - node.module.export_output_names - ) - - output_names = [] - # For cases where export_output_names should be used but - # output node's output is split into multiple subnodes - running_i = {} - for node_name, output_name, i in output_order: - if node_name in export_output_names_dict: - running_i[node_name] = ( - running_i.get(node_name, -1) + 1 - ) # if not present default to 0 otherwise add 1 - output_names.append( - export_output_names_dict[node_name][running_i[node_name]] - ) - else: - output_names.append( - f"{self.nodes[node_name].task_name}/{node_name}/{output_name}/{i}" - ) - - old_forward = self.forward - - def export_forward(inputs: dict[str, Tensor]) -> tuple[Tensor, ...]: - old_outputs = old_forward( - inputs, - None, - compute_loss=False, - compute_metrics=False, - compute_visualizations=False, - ).outputs - outputs = [] - for node_name, output_name, i in output_order: - node_output = old_outputs[node_name][output_name] - if isinstance(node_output, Tensor): - outputs.append(node_output) - else: - outputs.append(node_output[i]) - return tuple(outputs) - - self.forward = export_forward # type: ignore + output_names = self._get_output_onnx_names(deepcopy(inputs)) - if "input_names" not in kwargs: - kwargs["input_names"] = list(inputs.keys()) if "output_names" not in kwargs: kwargs["output_names"] = output_names @@ -401,15 +365,14 @@ def export_forward(inputs: dict[str, Tensor]) -> tuple[Tensor, ...]: # PyTorch 2.9 introduces a breaking change that # sets the default value to True kwargs.setdefault("dynamo", False) - self.to_onnx(save_path, inputs_for_onnx, **kwargs) - self.forward = old_forward # type: ignore + self.to_onnx(save_path, {"inputs": inputs}, **kwargs) logger.info(f"Model exported to {save_path}") self.set_export_mode(mode=False) self.train() - self.to(device_before) # reset device after export + self.to(device) # reset device after export return Path(save_path) @@ -417,7 +380,7 @@ def export_forward(inputs: dict[str, Tensor]) -> tuple[Tensor, ...]: def training_step( self, train_batch: tuple[dict[str, Tensor], Labels] ) -> Tensor: - outputs = self.forward(*train_batch) + outputs = self.run_forward_step(*train_batch) if not outputs.losses: raise ValueError("Losses are empty, check if you defined any loss") @@ -443,7 +406,7 @@ def predict_step( ) -> LuxonisOutput: inputs, labels = batch images = get_denormalized_images(self.cfg, inputs[self.image_source]) - return self.forward( + return self.run_forward_step( inputs, labels, images=images, @@ -624,6 +587,14 @@ def load_checkpoint(self, ckpt: PathType | dict[str, Any] | None) -> None: sub_state_dict, strict=False ) + def detach(self) -> None: + """Detaches the model from the trainer. + + This is useful when the model needs to be used outside of the + training loop, for example for inference or exporting. + """ + self.trainer = None + def _check_valid_epoch_counts(self, ckpt_config: dict) -> None: previous_trainer_cfg = ckpt_config.get("trainer", {}) previous_epochs = previous_trainer_cfg.get("epochs", None) @@ -655,7 +626,7 @@ def _evaluation_step( if self._n_logged_images < max_log_images: images = get_denormalized_images(self.cfg, input_image) - outputs = self.forward( + outputs = self.run_forward_step( inputs, labels, images=images, @@ -904,25 +875,22 @@ def get_mlflow_logging_keys(self) -> dict[str, list[str]]: ) for callback in self.cfg.trainer.callbacks: + model_name = self.cfg.exporter.name or self.cfg.model.name if callback.name == "UploadCheckpoint": artifact_keys.update( {"best_val_metric.ckpt", "min_val_loss.ckpt"} ) elif callback.name == "ExportOnTrainEnd": - artifact_keys.add( - f"{self.cfg.exporter.name or self.cfg.model.name}.onnx" - ) + artifact_keys.add(f"{model_name}.onnx") elif callback.name == "ArchiveOnTrainEnd": - artifact_keys.add( - f"{self.cfg.exporter.name or self.cfg.model.name}.onnx.tar.xz" - ) + artifact_keys.add(f"{model_name}.onnx.tar.xz") elif callback.name == "ConvertOnTrainEnd": - artifact_keys.add( - f"{self.cfg.exporter.name or self.cfg.model.name}.onnx" - ) - artifact_keys.add( - f"{self.cfg.exporter.name or self.cfg.model.name}.onnx.tar.xz" - ) + artifact_keys.add(f"{model_name}.onnx") + artifact_keys.add(f"{model_name}.onnx.tar.xz") + elif callback.name == "AIMETCallback": + artifact_keys.add(f"{model_name}.onnx") + artifact_keys.add(f"{model_name}.onnx.data") + artifact_keys.add(f"{model_name}.encodings") elif callback.name == "TrainingProgressCallback": metric_keys.update( { @@ -945,6 +913,10 @@ def get_mlflow_logging_keys(self) -> dict[str, list[str]]: "artifacts": sorted(artifact_keys), } + @override + def __getstate__(self): + return super().__getstate__() | {"_core": None} + def _get_node_order_mapping( self, node_name: str, old_order: list[str], new_order: list[str] ) -> dict[str, str]: @@ -971,3 +943,54 @@ def _get_node_order_mapping( def _strip_state_prefix(key: str) -> str: idx = 3 if "module." in key else 2 return ".".join(key.split(".")[idx:]) + + def _get_output_onnx_names(self, inputs: dict[str, Tensor]) -> list[str]: + outputs = self.run_forward_step(inputs).outputs + output_order = sorted( + [ + (node_name, output_name, i) + for node_name, outs in outputs.items() + for output_name, out in outs.items() + for i in range(len(out)) + ] + ) + + output_counts = defaultdict(int) + for node_name, outs in outputs.items(): + output_counts[node_name] = sum(len(out) for out in outs.values()) + + export_output_names_dict = {} + for node_name, node in self.nodes.items(): + if node.module.export_output_names is not None: + if ( + len(node.module.export_output_names) + != output_counts[node_name] + ): + logger.warning( + f"Number of provided output names for node {node_name} " + f"({len(node.module.export_output_names)}) does not match " + f"number of outputs ({output_counts[node_name]}). " + f"Using default names." + ) + else: + export_output_names_dict[node_name] = ( + node.module.export_output_names + ) + + output_names = [] + # For cases where export_output_names should be used but + # output node's output is split into multiple subnodes + running_i = {} + for node_name, output_name, i in output_order: + if node_name in export_output_names_dict: + running_i[node_name] = ( + running_i.get(node_name, -1) + 1 + ) # if not present default to 0 otherwise add 1 + output_names.append( + export_output_names_dict[node_name][running_i[node_name]] + ) + else: + output_names.append( + f"{self.nodes[node_name].task_name}/{node_name}/{output_name}/{i}" + ) + return output_names From 6e506ced5c08ee870c4ef0e61effc4d8b24bd1d2 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sun, 29 Mar 2026 17:36:49 +0200 Subject: [PATCH 03/74] fix for pickle --- luxonis_train/lightning/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/lightning/utils.py b/luxonis_train/lightning/utils.py index 1219e335..55852525 100644 --- a/luxonis_train/lightning/utils.py +++ b/luxonis_train/lightning/utils.py @@ -51,7 +51,7 @@ class MainMetric(NamedTuple): class LossAccumulator(defaultdict[str, float]): - def __init__(self): + def __init__(self, *args, **kwargs): super().__init__(float) self.counts = defaultdict(int) From 0e3673bf2e808bbd1ad021b5bf5000b14e1d7529 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sun, 29 Mar 2026 17:54:22 +0200 Subject: [PATCH 04/74] ignoring quant modules --- luxonis_train/assigners/__init__.py | 9 +++++++++ luxonis_train/attached_modules/base_attached_module.py | 4 ++++ .../attached_modules/losses/adaptive_detection_loss.py | 5 +++++ 3 files changed, 18 insertions(+) diff --git a/luxonis_train/assigners/__init__.py b/luxonis_train/assigners/__init__.py index 0b8b074a..6aa863c9 100644 --- a/luxonis_train/assigners/__init__.py +++ b/luxonis_train/assigners/__init__.py @@ -1,4 +1,13 @@ +from contextlib import suppress + from .atss_assigner import ATSSAssigner from .tal_assigner import TaskAlignedAssigner +with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin + + QuantizationMixin.ignore(ATSSAssigner) + QuantizationMixin.ignore(TaskAlignedAssigner) + + __all__ = ["ATSSAssigner", "TaskAlignedAssigner"] diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py index 67487ddf..3edb8d19 100644 --- a/luxonis_train/attached_modules/base_attached_module.py +++ b/luxonis_train/attached_modules/base_attached_module.py @@ -81,6 +81,10 @@ def __init__(self, *, node: BaseNode | None = None, **kwargs): self._task = None self._check_node_type_override() + with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin + + QuantizationMixin.ignore(self.__class__) @property def current_epoch(self) -> int: diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index e72ace80..bf5f85ba 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -1,3 +1,4 @@ +from contextlib import suppress from typing import Literal, cast import torch @@ -277,6 +278,10 @@ def __init__( self.alpha = alpha self.gamma = gamma self.per_class_weights = per_class_weights + with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin + + QuantizationMixin.ignore(self.__class__) def forward( self, pred_score: Tensor, target_score: Tensor, label: Tensor From 9cd682672222a44a0369e32adefc043da384cf72 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 30 Mar 2026 16:01:26 +0200 Subject: [PATCH 05/74] ptq support --- luxonis_train/__main__.py | 3 +- .../losses/adaptive_detection_loss.py | 51 +++++++++++++------ luxonis_train/callbacks/aimet_callback.py | 8 ++- luxonis_train/core/core.py | 48 ++++++++++++++--- 4 files changed, 82 insertions(+), 28 deletions(-) diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index 6e6d356d..3c9e6f8c 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -441,9 +441,10 @@ def quantize( *, config: str | None = None, weights: str | None = None, + epochs: int = 4, ): model = create_model(config, opts, weights=weights, debug_mode=True) - model.quantize() + model.quantize(epochs=epochs) @upgrade_app.command() diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index bf5f85ba..f10e1c99 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -30,6 +30,7 @@ class AdaptiveDetectionLoss(BaseLoss): n_anchors_list: list[int] stride_tensor: Tensor gt_bboxes_scale: Tensor + anchor_points_strided: Tensor def __init__( self, @@ -103,6 +104,19 @@ def __init__( self.class_loss_weight = class_loss_weight self.iou_loss_weight = iou_loss_weight + self.register_buffer( + "gt_bboxes_scale", + torch.tensor( + [ + self.original_img_size[1], + self.original_img_size[0], + self.original_img_size[1], + self.original_img_size[0], + ], + ), + persistent=False, + ) + self._logged_assigner_change = False def forward( @@ -164,21 +178,12 @@ def forward( return loss, sub_losses def _init_parameters(self, features: list[Tensor]) -> None: - if not hasattr(self, "gt_bboxes_scale"): - self.gt_bboxes_scale = torch.tensor( - [ - self.original_img_size[1], - self.original_img_size[0], - self.original_img_size[1], - self.original_img_size[0], - ], - device=features[0].device, - ) + if not hasattr(self, "anchors"): ( - self.anchors, - self.anchor_points, - self.n_anchors_list, - self.stride_tensor, + anchors, + anchor_points, + n_anchors_list, + stride_tensor, ) = anchors_for_fpn_features( features, self.stride, @@ -186,8 +191,22 @@ def _init_parameters(self, features: list[Tensor]) -> None: self.grid_cell_offset, multiply_with_stride=True, ) - self.anchor_points_strided = ( - self.anchor_points / self.stride_tensor + self.register_buffer("anchors", anchors, persistent=False) + self.register_buffer( + "anchor_points", anchor_points, persistent=False + ) + self.register_buffer( + "n_anchors_list", + torch.tensor(n_anchors_list), + persistent=False, + ) + self.register_buffer( + "stride_tensor", stride_tensor, persistent=False + ) + self.register_buffer( + "anchor_points_strided", + anchor_points / stride_tensor, + persistent=False, ) def _run_assigner( diff --git a/luxonis_train/callbacks/aimet_callback.py b/luxonis_train/callbacks/aimet_callback.py index 09033dd4..288777f5 100644 --- a/luxonis_train/callbacks/aimet_callback.py +++ b/luxonis_train/callbacks/aimet_callback.py @@ -1,5 +1,3 @@ -from typing import Literal - import lightning.pytorch as pl import luxonis_train as lxt @@ -9,11 +7,11 @@ @CALLBACKS.register() class AIMETCallback(NeedsCheckpoint): - def __init__(self, mode: Literal["PTQ", "QAT"] = "PTQ"): + def __init__(self, epochs: int = 4): super().__init__() - self.mode = mode + self.epochs = epochs def on_train_end( self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule" ) -> None: - pl_module.core.quantize() + pl_module.core.quantize(self.epochs) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 39dd1d32..2e09ca44 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -11,6 +11,7 @@ import torch import torch.utils.data as torch_data import yaml +from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities import rank_zero_only from loguru import logger @@ -19,7 +20,8 @@ from luxonis_ml.nn_archive.config import CONFIG_VERSION from luxonis_ml.typing import Params, PathType from luxonis_ml.utils import Environ, LuxonisFileSystem -from torch import nn +from rich.progress import track +from torch import Tensor, nn from typeguard import typechecked from luxonis_train.callbacks import ( @@ -224,7 +226,10 @@ def __init__( "Weighted sampler is not implemented yet." ) - self.pytorch_loaders: dict[View, torch_data.DataLoader] = {} + self.pytorch_loaders: dict[ + View, + torch_data.DataLoader[tuple[dict[str, Tensor], dict[str, Tensor]]], + ] = {} for view in ("train", "val", "test"): if self.cfg.trainer.n_validation_batches is not None and view in { "val", @@ -1158,7 +1163,7 @@ def convert( return archive_path, conversion_artifacts - def quantize(self) -> None: + def quantize(self, epochs: int = 4) -> None: from aimet_torch import QuantizationSimModel model = self.lightning_module @@ -1193,8 +1198,39 @@ def pass_calibration_data(model: nn.Module) -> None: sim.compute_encodings(pass_calibration_data) - post_quant_test = self.test(view="val") + ptq_test = self.test(view="val") + model.train() + if CUDAAccelerator.is_available(): + model.cuda() + model.automatic_optimization = False + + for e in track( + range(epochs), description="Running Quantization-Aware Training..." + ): + for imgs, labels in self.pytorch_loaders["train"]: + imgs = {k: v.to(model.device) for k, v in imgs.items()} + labels = {k: v.to(model.device) for k, v in labels.items()} + loss = model.training_step((imgs, labels)) + optimizers = model.optimizers() + schedulers = model.lr_schedulers() + if not isinstance(optimizers, list): + optimizers = [optimizers] + if not isinstance(schedulers, list): + schedulers = [schedulers] + for optimizer in optimizers: + optimizer.zero_grad() + model.manual_backward(loss) + for optimizer in optimizers: + optimizer.step() + for scheduler in schedulers: + if scheduler is not None: + scheduler.step(e) + + model.automatic_optimization = True + model.eval() + qat_test = self.test(view="val") model.set_export_mode(mode=True) + sim.onnx.export( inputs, (save_dir / self.cfg.model.name).with_suffix(".onnx"), @@ -1205,11 +1241,11 @@ def pass_calibration_data(model: nn.Module) -> None: table = [] for key, value in pre_quant_test.items(): log_key = key.replace("test/metric/", "").replace("test/loss/", "") - table.append((log_key, value, post_quant_test[key])) + table.append((log_key, value, ptq_test[key], qat_test[key])) model.progress_bar.print_table( "Quantization results", table, - ["Name", "Pre-quantization", "Post-quantization"], + ["Name", "Pre-Quant", "PTQ", "QAT"], ) @property From 3d8828a17a93cfadbaef6d8843973b4e75a80ffa Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 31 Mar 2026 16:49:28 +0200 Subject: [PATCH 06/74] track --- luxonis_train/core/core.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2e09ca44..4ea4016a 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1193,12 +1193,18 @@ def quantize(self, epochs: int = 4) -> None: ) def pass_calibration_data(model: nn.Module) -> None: - for imgs, _ in self.pytorch_loaders["val"]: + for imgs, _ in track( + self.pytorch_loaders["val"], + description="Computing quantization encodings...", + total=len(self.pytorch_loaders["val"]), + ): + imgs = {k: v.to(model.device) for k, v in imgs.items()} model.forward(imgs) sim.compute_encodings(pass_calibration_data) ptq_test = self.test(view="val") + model.train() if CUDAAccelerator.is_available(): model.cuda() From 7dca296e97dd59eab043e338a6044473be1d78ff Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 31 Mar 2026 22:10:12 +0200 Subject: [PATCH 07/74] more params --- luxonis_train/callbacks/aimet_callback.py | 26 ++++++++++++-- luxonis_train/core/core.py | 44 +++++++++++++++++++++-- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/luxonis_train/callbacks/aimet_callback.py b/luxonis_train/callbacks/aimet_callback.py index 288777f5..74b4567c 100644 --- a/luxonis_train/callbacks/aimet_callback.py +++ b/luxonis_train/callbacks/aimet_callback.py @@ -1,4 +1,5 @@ import lightning.pytorch as pl +from aimet_torch.common.defs import QuantizationDataType, QuantScheme import luxonis_train as lxt from luxonis_train.callbacks.needs_checkpoint import NeedsCheckpoint @@ -7,11 +8,32 @@ @CALLBACKS.register() class AIMETCallback(NeedsCheckpoint): - def __init__(self, epochs: int = 4): + def __init__( + self, + epochs: int = 4, + quant_scheme: str | QuantScheme = QuantScheme.min_max, + default_output_bw: int = 8, + default_param_bw: int = 8, + config_file: str | None = None, + default_data_type: QuantizationDataType = QuantizationDataType.int, + ): super().__init__() self.epochs = epochs + self.quant_scheme = quant_scheme + self.default_output_bw = default_output_bw + self.default_param_bw = default_param_bw + self.config_file = config_file + self.default_data_type = default_data_type def on_train_end( self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule" ) -> None: - pl_module.core.quantize(self.epochs) + pl_module.core.quantize( + self.get_checkpoint(pl_module), + self.epochs, + self.quant_scheme, + self.default_output_bw, + self.default_param_bw, + self.config_file, + self.default_data_type, + ) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 4ea4016a..2df19b2d 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -11,6 +11,8 @@ import torch import torch.utils.data as torch_data import yaml +from aimet_torch import QuantizationSimModel +from aimet_torch.common.defs import QuantizationDataType, QuantScheme from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities import rank_zero_only @@ -1163,10 +1165,43 @@ def convert( return archive_path, conversion_artifacts - def quantize(self, epochs: int = 4) -> None: - from aimet_torch import QuantizationSimModel + def quantize( + self, + weights: PathType | None = None, + epochs: int = 4, + quant_scheme: str | QuantScheme = QuantScheme.min_max, + default_output_bw: int = 8, + default_param_bw: int = 8, + config_file: str | None = None, + default_data_type: QuantizationDataType = QuantizationDataType.int, + ) -> None: + """Runs post-training quantization and quantization-aware + training using AIMET. + + @type weights: PathType | None + @param weights: Path to the checkpoint from which to load weights. + @type epochs: int + @param epochs: Number of epochs to run quantization-aware training for. + @type quant_scheme: str | QuantScheme + @param quant_scheme: Quantization scheme to use. Can be either a string + or an instance of `aimet_common.defs.QuantScheme`. If a string is + provided, it will be converted to the corresponding `QuantScheme` + instance. Defaults to `QuantScheme.min_max`. + @type default_output_bw: int + @param default_output_bw: Default bitwidth to use for quantizing outputs. + @type default_param_bw: int + @param default_param_bw: Default bitwidth to use for quantizing parameters. + @type config_file: str | None + @param config_file: Path to the AIMET config file specifying quantization settings for specific layers. If not provided, default quantization settings will be applied to all layers. + @type default_data_type: QuantizationDataType + @param default_data_type: Data type to use for quantization (e.g. integer or float). Defaults to `QuantizationDataType.int`. + """ model = self.lightning_module + + if weights is not None: + model.load_checkpoint(weights) + save_dir = self.run_save_dir / "aimet" save_dir.mkdir(parents=True, exist_ok=True) pre_quant_test = self.test(view="val") @@ -1189,6 +1224,11 @@ def quantize(self, epochs: int = 4) -> None: sim = QuantizationSimModel( model=model, dummy_input=inputs, + quant_scheme=quant_scheme, + default_output_bw=default_output_bw, + default_param_bw=default_param_bw, + config_file=config_file, + default_data_type=default_data_type, in_place=True, ) From 58565e4c9acba45fdd6373319f54043b08894392 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 31 Mar 2026 22:15:29 +0200 Subject: [PATCH 08/74] aimet config --- luxonis_train/callbacks/aimet_callback.py | 3 ++- luxonis_train/config/config.py | 26 +++++++++++++++++++++++ luxonis_train/lightning/utils.py | 14 ++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/luxonis_train/callbacks/aimet_callback.py b/luxonis_train/callbacks/aimet_callback.py index 74b4567c..aef7499d 100644 --- a/luxonis_train/callbacks/aimet_callback.py +++ b/luxonis_train/callbacks/aimet_callback.py @@ -16,8 +16,9 @@ def __init__( default_param_bw: int = 8, config_file: str | None = None, default_data_type: QuantizationDataType = QuantizationDataType.int, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.epochs = epochs self.quant_scheme = quant_scheme self.default_output_bw = default_output_bw diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 14aa9002..b48f3a2e 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -1,9 +1,11 @@ +import json import sys from collections.abc import Mapping from contextlib import suppress from pathlib import Path from typing import Annotated, Any, Literal, NamedTuple +from aimet_torch.common.defs import QuantizationDataType, QuantScheme from loguru import logger from luxonis_ml.enums import DatasetType from luxonis_ml.typing import ( @@ -726,6 +728,29 @@ def _validate_quantization_mode(value: str) -> str: return value +class AIMETConfig(BaseModelExtraForbid): + active: bool = False + epochs: PositiveInt = 4 + default_output_bw: Literal[4, 8, 16] = 8 + default_param_bw: Literal[4, 8, 16] = 8 + default_data_type: QuantizationDataType = QuantizationDataType.int + quant_scheme: QuantScheme = QuantScheme.min_max + config: Params | None = None + + @field_validator("config", mode="before") + @classmethod + def validate_config(cls, value: ParamValue) -> Any: + if isinstance(value, str): + try: + fs = LuxonisFileSystem(value) + return json.loads(fs.read_text("")) + except Exception as e: + raise ValueError( + f"Failed to load AIMET config from file '{value}': {e}" + ) from e + return value + + class ExportConfig(ArchiveConfig): name: str | None = None input_shape: list[int] | None = None @@ -742,6 +767,7 @@ class ExportConfig(ArchiveConfig): default_factory=BlobconverterExportConfig ) hubai: HubAIExportConfig = Field(default_factory=HubAIExportConfig) + aimet: AIMETConfig = Field(default_factory=AIMETConfig) @field_validator("scale_values", "mean_values", mode="before") @classmethod diff --git a/luxonis_train/lightning/utils.py b/luxonis_train/lightning/utils.py index 55852525..30c8d6ce 100644 --- a/luxonis_train/lightning/utils.py +++ b/luxonis_train/lightning/utils.py @@ -519,6 +519,20 @@ def build_callbacks( "in the callbacks list. The `accumulate_grad_batches` " "parameter in the config will be ignored." ) + if cfg.exporter.aimet.active: + aimet_cfg = cfg.exporter.aimet + callbacks.append( + from_registry( + CALLBACKS, + "AIMETCallback", + epochs=aimet_cfg.epochs, + quant_scheme=aimet_cfg.quant_scheme, + default_output_bw=aimet_cfg.default_output_bw, + default_param_bw=aimet_cfg.default_param_bw, + default_data_type=aimet_cfg.default_data_type, + config_file=aimet_cfg.config, + ) + ) return callbacks From 40518c0e5a3eb9bd53e218785123b6e85e0c6a31 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 31 Mar 2026 22:16:07 +0200 Subject: [PATCH 09/74] requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 904977a2..da5f7e75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ torch<2.11 # temporary pin: GitHub runner image does not yet support the newer N torchmetrics~=1.8 torchvision~=0.24 hubai-sdk>=0.2.1 +aimet-torch~=2.27 From fa33a8db213af298b8b088d6d6894532284ac365 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 1 Apr 2026 16:31:37 +0200 Subject: [PATCH 10/74] serialization --- luxonis_train/config/config.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index b48f3a2e..fa784e02 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -2,6 +2,7 @@ import sys from collections.abc import Mapping from contextlib import suppress +from enum import Enum from pathlib import Path from typing import Annotated, Any, Literal, NamedTuple @@ -750,6 +751,12 @@ def validate_config(cls, value: ParamValue) -> Any: ) from e return value + @field_serializer("default_data_type", "quant_scheme") + def serialize_enums(self, value: Any) -> str: + if isinstance(value, Enum): + return value.name + return value + class ExportConfig(ArchiveConfig): name: str | None = None From c0511443471595aadb4756ba1b89ee70ea9fb822 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 1 Apr 2026 23:27:06 +0200 Subject: [PATCH 11/74] more ptq techniques --- .../attached_modules/base_attached_module.py | 7 +- .../attached_modules/metrics/base_metric.py | 6 + luxonis_train/core/core.py | 108 +++++++++++++--- luxonis_train/lightning/luxonis_lightning.py | 116 +++++++++++------- luxonis_train/loaders/base_loader.py | 18 +-- luxonis_train/loaders/luxonis_loader_torch.py | 8 +- luxonis_train/nodes/blocks/blocks.py | 2 +- luxonis_train/registry.py | 4 +- 8 files changed, 190 insertions(+), 79 deletions(-) diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py index 3edb8d19..5a5ca0e5 100644 --- a/luxonis_train/attached_modules/base_attached_module.py +++ b/luxonis_train/attached_modules/base_attached_module.py @@ -57,7 +57,7 @@ def __init__(self, *, node: BaseNode | None = None, **kwargs): @param kwargs: Additional keyword arguments. """ super().__init__(**kwargs) - self._node = node + self._node = (node,) if node is not None and node.task is not None: if ( @@ -124,12 +124,13 @@ def node(self) -> BaseNode: @raises RuntimeError: If the node was not provided during initialization. """ - if self._node is None: + node = self._node[0] + if node is None: raise RuntimeError( "Attempt to access `node` reference, but it was not " "provided during initialization." ) - return self._node + return node @property def n_keypoints(self) -> int: diff --git a/luxonis_train/attached_modules/metrics/base_metric.py b/luxonis_train/attached_modules/metrics/base_metric.py index cdfd6f16..da6bfb9b 100644 --- a/luxonis_train/attached_modules/metrics/base_metric.py +++ b/luxonis_train/attached_modules/metrics/base_metric.py @@ -154,6 +154,12 @@ def compute( """ return super().compute() + def __eq__(self, other: object) -> bool: + return id(self) == id(other) + + def __hash__(self) -> int: + return id(self) + @cached_property def _signature(self) -> dict[str, Parameter]: return get_signature(self.update) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 2df19b2d..c48fcfe2 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1,9 +1,10 @@ +import math import threading from collections.abc import Mapping from copy import deepcopy from pathlib import Path from threading import ExceptHookArgs, Thread -from typing import Literal, overload +from typing import Any, Literal, cast, overload import lightning.pytorch as pl import lightning_utilities.core.rank_zero as rank_zero_module @@ -12,7 +13,11 @@ import torch.utils.data as torch_data import yaml from aimet_torch import QuantizationSimModel +from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters +from aimet_torch.bn_reestimation import reestimate_bn_stats from aimet_torch.common.defs import QuantizationDataType, QuantScheme +from aimet_torch.cross_layer_equalization import equalize_model +from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities import rank_zero_only @@ -23,7 +28,7 @@ from luxonis_ml.typing import Params, PathType from luxonis_ml.utils import Environ, LuxonisFileSystem from rich.progress import track -from torch import Tensor, nn +from torch import nn from typeguard import typechecked from luxonis_train.callbacks import ( @@ -40,6 +45,7 @@ DebugLoader, LuxonisLoaderTorch, ) +from luxonis_train.loaders.base_loader import LuxonisLoaderTorchOutput from luxonis_train.registry import LOADERS from luxonis_train.typing import View from luxonis_train.utils import ( @@ -230,7 +236,7 @@ def __init__( self.pytorch_loaders: dict[ View, - torch_data.DataLoader[tuple[dict[str, Tensor], dict[str, Tensor]]], + torch_data.DataLoader[LuxonisLoaderTorchOutput], ] = {} for view in ("train", "val", "test"): if self.cfg.trainer.n_validation_batches is not None and view in { @@ -1168,12 +1174,17 @@ def convert( def quantize( self, weights: PathType | None = None, - epochs: int = 4, + epochs: int = 20, quant_scheme: str | QuantScheme = QuantScheme.min_max, default_output_bw: int = 8, default_param_bw: int = 8, config_file: str | None = None, default_data_type: QuantizationDataType = QuantizationDataType.int, + adaround: bool = True, + adaround_iterations: int | None = None, + fold_batch_norms: bool = False, + cross_layer_equalization: bool = False, + batch_norm_reestimation: bool = False, ) -> None: """Runs post-training quantization and quantization-aware training using AIMET. @@ -1195,9 +1206,26 @@ def quantize( @param config_file: Path to the AIMET config file specifying quantization settings for specific layers. If not provided, default quantization settings will be applied to all layers. @type default_data_type: QuantizationDataType @param default_data_type: Data type to use for quantization (e.g. integer or float). Defaults to `QuantizationDataType.int`. + + @warning: Only in-place quantization is currently supported, + meaning that the original model will be modified and + exported ONNX model will be the quantized version of + the original model. Make sure to keep a backup of the + original model weights if you want to continue training + or exporting the original model after quantization. """ + def pass_calibration_data(model: nn.Module) -> None: + for imgs, _ in track( + self.pytorch_loaders["val"], + description="Computing quantization encodings...", + total=len(self.pytorch_loaders["val"]), + ): + model.forward(imgs) + model = self.lightning_module + model.reparametrize() + loader = self.pytorch_loaders["val"] if weights is not None: model.load_checkpoint(weights) @@ -1221,6 +1249,41 @@ def quantize( output_names = model._get_output_onnx_names(deepcopy(inputs)) inputs = next(iter(inputs.values())) + if CUDAAccelerator.is_available(): + inputs = inputs.cuda() + model.cuda() + + if fold_batch_norms and not batch_norm_reestimation: + logger.info("Folding batch norms into preceding layers") + fold_all_batch_norms( + model, input_shapes=inputs.shape, dummy_input=inputs + ) + if cross_layer_equalization: + logger.info("Applying cross-layer equalization") + equalize_model( + model, input_shapes=inputs.shape, dummy_input=inputs + ) + + if adaround: + ada_params = AdaroundParameters( + data_loader=loader, + num_batches=min( + len(loader), + math.ceil(2000 / self.cfg.trainer.batch_size), + ), + default_num_iterations=adaround_iterations, # type: ignore + ) + model = cast( + LuxonisLightningModule, + Adaround.apply_adaround( + model, + inputs, + ada_params, + path=str(save_dir), + filename_prefix="adaround", + ), + ) + sim = QuantizationSimModel( model=model, dummy_input=inputs, @@ -1231,19 +1294,16 @@ def quantize( default_data_type=default_data_type, in_place=True, ) + model = cast(LuxonisLightningModule, sim.model) - def pass_calibration_data(model: nn.Module) -> None: - for imgs, _ in track( - self.pytorch_loaders["val"], - description="Computing quantization encodings...", - total=len(self.pytorch_loaders["val"]), - ): - imgs = {k: v.to(model.device) for k, v in imgs.items()} - model.forward(imgs) + if adaround: + sim.set_and_freeze_param_encodings( + str(save_dir / "adaround.encodings") + ) sim.compute_encodings(pass_calibration_data) - ptq_test = self.test(view="val") + ptq_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] model.train() if CUDAAccelerator.is_available(): @@ -1254,8 +1314,6 @@ def pass_calibration_data(model: nn.Module) -> None: range(epochs), description="Running Quantization-Aware Training..." ): for imgs, labels in self.pytorch_loaders["train"]: - imgs = {k: v.to(model.device) for k, v in imgs.items()} - labels = {k: v.to(model.device) for k, v in labels.items()} loss = model.training_step((imgs, labels)) optimizers = model.optimizers() schedulers = model.lr_schedulers() @@ -1272,9 +1330,27 @@ def pass_calibration_data(model: nn.Module) -> None: if scheduler is not None: scheduler.step(e) + if batch_norm_reestimation: + logger.info("Reestimating batch norm statistics") + + def _forward_pass( + model: nn.Module, inputs: LuxonisLoaderTorchOutput + ) -> Any: + return model(inputs[0]) + + reestimate_bn_stats( + model, self.pytorch_loaders["train"], forward_fn=_forward_pass + ) + + if fold_batch_norms: + logger.info("Folding batch norms into preceding layers") + fold_all_batch_norms( + model, input_shapes=inputs.shape, dummy_input=inputs + ) + model.automatic_optimization = True model.eval() - qat_test = self.test(view="val") + qat_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] model.set_export_mode(mode=True) sim.onnx.export( diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index 404ed3bd..3f66b519 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -25,6 +25,8 @@ from luxonis_train.callbacks import BaseLuxonisProgressBar from luxonis_train.config import Config from luxonis_train.nodes import BaseNode +from luxonis_train.nodes.blocks.reparametrizable import Reparametrizable +from luxonis_train.registry import _INTERNAL from luxonis_train.typing import Labels, Packet from luxonis_train.utils import DatasetMetadata, LuxonisTrackerPL @@ -188,9 +190,45 @@ def core(self) -> "luxonis_train.core.LuxonisModel": raise ValueError("Core reference is not set.") return self._core - def run_forward_step( + @override + def forward( + self, inputs: dict[str, Tensor] | Tensor + ) -> tuple[Tensor, ...]: + """Forward pass of the model. + + @type inputs: L{Tensor} + @param inputs: Input tensors. + @rtype: dict[str, L{Packet}[L{Tensor}]] + @return: Output of the model. + """ + outputs = self._run_forward( + inputs, + compute_loss=False, + compute_metrics=False, + compute_visualizations=False, + ).outputs + + output_order = sorted( + [ + (node_name, output_name, i) + for node_name, outs in outputs.items() + for output_name, out in outs.items() + for i in range(len(out)) + ] + ) + new_outputs = [] + for node_name, output_name, i in output_order: + node_output = outputs[node_name][output_name] + if isinstance(node_output, Tensor): + new_outputs.append(node_output) + else: + new_outputs.append(node_output[i]) + + return tuple(new_outputs) + + def _run_forward( self, - inputs: dict[str, Tensor], + inputs: dict[str, Tensor] | Tensor, labels: Labels | None = None, images: Tensor | None = None, *, @@ -224,6 +262,12 @@ def run_forward_step( @rtype: L{LuxonisOutput} @return: Output of the model. """ + if isinstance(inputs, Tensor): + inputs = {self.image_source: inputs} + + inputs = {k: v.to(self.device) for k, v in inputs.items()} + if labels is not None: + labels = {k: v.to(self.device) for k, v in labels.items()} losses: dict[ str, dict[str, Tensor | tuple[Tensor, dict[str, Tensor]]] ] = defaultdict(dict) @@ -286,50 +330,16 @@ def run_forward_step( outputs=outputs_dict, losses=losses, visualizations=visualizations ) - @override - def forward( - self, inputs: dict[str, Tensor] | Tensor - ) -> tuple[Tensor, ...]: - """Forward pass of the model. - - @type inputs: L{Tensor} - @param inputs: Input tensors. - @rtype: dict[str, L{Packet}[L{Tensor}]] - @return: Output of the model. - """ - if isinstance(inputs, Tensor): - inputs = {self.image_source: inputs} - - outputs = self.run_forward_step( - inputs, - compute_loss=False, - compute_metrics=False, - compute_visualizations=False, - ).outputs - - output_order = sorted( - [ - (node_name, output_name, i) - for node_name, outs in outputs.items() - for output_name, out in outs.items() - for i in range(len(out)) - ] - ) - new_outputs = [] - for node_name, output_name, i in output_order: - node_output = outputs[node_name][output_name] - if isinstance(node_output, Tensor): - new_outputs.append(node_output) - else: - new_outputs.append(node_output[i]) - - return tuple(new_outputs) - def set_export_mode(self, *, mode: bool) -> None: for module in self.modules(): if isinstance(module, BaseNode): module.set_export_mode(mode=mode) + def reparametrize(self) -> None: + for module in self.modules(): + if isinstance(module, Reparametrizable): + module.reparametrize() + def export_onnx(self, save_path: PathType, **kwargs) -> Path: """Exports the model to ONNX format. @@ -380,7 +390,7 @@ def export_onnx(self, save_path: PathType, **kwargs) -> Path: def training_step( self, train_batch: tuple[dict[str, Tensor], Labels] ) -> Tensor: - outputs = self.run_forward_step(*train_batch) + outputs = self._run_forward(*train_batch) if not outputs.losses: raise ValueError("Losses are empty, check if you defined any loss") @@ -406,7 +416,7 @@ def predict_step( ) -> LuxonisOutput: inputs, labels = batch images = get_denormalized_images(self.cfg, inputs[self.image_source]) - return self.run_forward_step( + return self._run_forward( inputs, labels, images=images, @@ -612,10 +622,12 @@ def _check_valid_epoch_counts(self, ckpt_config: dict) -> None: def _evaluation_step( self, mode: Literal["test", "val"], - inputs: dict[str, Tensor], + inputs: dict[str, Tensor] | Tensor, labels: Labels, ) -> dict[str, Tensor]: max_log_images = self.cfg.trainer.n_log_images + if isinstance(inputs, Tensor): + inputs = {self.image_source: inputs} input_image = inputs[self.image_source] # Smart logging is decided based on the classification task keys that are merged for all tasks @@ -626,7 +638,7 @@ def _evaluation_step( if self._n_logged_images < max_log_images: images = get_denormalized_images(self.cfg, input_image) - outputs = self.run_forward_step( + outputs = self._run_forward( inputs, labels, images=images, @@ -919,7 +931,17 @@ def get_mlflow_logging_keys(self) -> dict[str, list[str]]: @override def __getstate__(self): - return super().__getstate__() | {"_core": None} + state = super().__getstate__() + state["_core"] = None + _INTERNAL["trainer"] = self._trainer + _INTERNAL["core"] = self._core + return state + + @override + def __setstate__(self, state: Any): + super().__setstate__(state) + self._trainer = _INTERNAL.get("trainer") # type: ignore + self._core = _INTERNAL.get("core") def _get_node_order_mapping( self, node_name: str, old_order: list[str], new_order: list[str] @@ -949,7 +971,7 @@ def _strip_state_prefix(key: str) -> str: return ".".join(key.split(".")[idx:]) def _get_output_onnx_names(self, inputs: dict[str, Tensor]) -> list[str]: - outputs = self.run_forward_step(inputs).outputs + outputs = self._run_forward(inputs).outputs output_order = sorted( [ (node_name, output_name, i) diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index f2176b70..276a1856 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -16,7 +16,7 @@ from luxonis_train.typing import Labels from luxonis_train.utils.general import get_attribute_check_none -LuxonisLoaderTorchOutput = tuple[dict[str, Tensor], Labels] +LuxonisLoaderTorchOutput = tuple[dict[str, Tensor] | Tensor, Labels] class BaseLoaderTorch( @@ -225,10 +225,7 @@ def augment_test_image(self, img: dict[str, Tensor]) -> Tensor: ) def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: - img, labels = self.get(idx) - if isinstance(img, Tensor): - img = {self.image_source: img} - return img, labels + return self.get(idx) @abstractmethod def __len__(self) -> int: @@ -342,13 +339,16 @@ def collate_fn( @return: Tuple of inputs and annotations in the format expected by the model. """ - inputs: tuple[dict[str, Tensor], ...] + inputs: tuple[dict[str, Tensor], ...] | tuple[Tensor, ...] labels: tuple[Labels, ...] inputs, labels = zip(*batch, strict=True) - out_inputs = { - k: torch.stack([i[k] for i in inputs], 0) for k in inputs[0] - } + if isinstance(inputs[0], dict): + out_inputs = { + k: torch.stack([i[k] for i in inputs], 0) for k in inputs[0] + } + else: + out_inputs = torch.stack(inputs, 0) out_labels: Labels = {} diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index 673d9611..e1d39cce 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -151,11 +151,13 @@ def __len__(self) -> int: @property @override def input_shapes(self) -> dict[str, Size]: - img = self[0][0][self.image_source] + img = self[0][0] + if isinstance(img, dict): + img = img[self.image_source] return {self.image_source: img.shape} @override - def get(self, idx: int) -> tuple[dict[str, Tensor], Labels]: + def get(self, idx: int) -> tuple[dict[str, Tensor] | Tensor, Labels]: img, labels = self.loader[idx] if isinstance(img, np.ndarray): img = {self.image_source: img} @@ -164,6 +166,8 @@ def get(self, idx: int) -> tuple[dict[str, Tensor], Labels]: labels = self._remap_keypoints(labels) img = {k: self.img_numpy_to_torch(v) for k, v in img.items()} + if len(img) == 1: + img = next(iter(img.values())) return img, self.dict_numpy_to_torch(labels) diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index 79e27a5d..a7722e1f 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -467,7 +467,7 @@ def name(self) -> str: @override def reparametrize(self) -> None: if self.fused_branch is not None: - raise RuntimeError(f"{self.name} is already reparametrized") + return kernel, bias = self._fuse_parameters() fused_branch = nn.Conv2d( diff --git a/luxonis_train/registry.py b/luxonis_train/registry.py index 258679e5..f1c60bdf 100644 --- a/luxonis_train/registry.py +++ b/luxonis_train/registry.py @@ -1,7 +1,7 @@ """This module implements a metaclass for automatic registration of classes.""" -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from luxonis_ml.utils.registry import Registry @@ -34,6 +34,8 @@ VISUALIZERS: Registry[type["lxt.BaseVisualizer"]] = Registry("visualizers") +_INTERNAL: dict[str, Any] = {} + T = TypeVar("T") From 4923c8ebbe54892d4e169c9dbc9353305a964f4f Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 2 Apr 2026 16:42:35 +0200 Subject: [PATCH 12/74] small fixes --- .../base_predefined_model.py | 2 +- luxonis_train/loaders/base_loader.py | 25 ++++++++++--------- luxonis_train/loaders/luxonis_loader_torch.py | 4 ++- .../loaders/luxonis_perlin_loader_torch.py | 2 +- luxonis_train/utils/__init__.py | 9 ------- 5 files changed, 18 insertions(+), 24 deletions(-) diff --git a/luxonis_train/config/predefined_models/base_predefined_model.py b/luxonis_train/config/predefined_models/base_predefined_model.py index aadefa0b..626ef4c1 100644 --- a/luxonis_train/config/predefined_models/base_predefined_model.py +++ b/luxonis_train/config/predefined_models/base_predefined_model.py @@ -115,7 +115,7 @@ def __init__( self._metrics = ( [metrics] if isinstance(metrics, str) else metrics or [] ) - if main_metric is None: + if main_metric is None and self._metrics: if len(self._metrics) == 1: main_metric = self._metrics[0] else: diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py index 276a1856..f7c23f33 100644 --- a/luxonis_train/loaders/base_loader.py +++ b/luxonis_train/loaders/base_loader.py @@ -224,16 +224,8 @@ def augment_test_image(self, img: dict[str, Tensor]) -> Tensor: "`augment_test_image` method to expose this functionality." ) - def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: - return self.get(idx) - - @abstractmethod - def __len__(self) -> int: - """Returns length of the dataset.""" - ... - @abstractmethod - def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]: + def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: """Loads sample from dataset. @type idx: int @@ -243,6 +235,11 @@ def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]: """ ... + @abstractmethod + def __len__(self) -> int: + """Returns length of the dataset.""" + ... + @abstractmethod def get_classes(self) -> dict[str, dict[str, int]]: """Gets classes according to computer vision task. @@ -328,7 +325,7 @@ def img_numpy_to_torch(img: np.ndarray) -> Tensor: def collate_fn( self, batch: list[LuxonisLoaderTorchOutput], - ) -> tuple[dict[str, Tensor], Labels]: + ) -> tuple[dict[str, Tensor] | Tensor, Labels]: """Default collate function used for training. @type batch: list[LuxonisLoaderTorchOutput] @@ -345,10 +342,14 @@ def collate_fn( if isinstance(inputs[0], dict): out_inputs = { - k: torch.stack([i[k] for i in inputs], 0) for k in inputs[0] + k: torch.stack( + [i[k] for i in inputs], # type: ignore + 0, + ) + for k in inputs[0] } else: - out_inputs = torch.stack(inputs, 0) + out_inputs = torch.stack(inputs, 0) # type: ignore out_labels: Labels = {} diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py index e1d39cce..540ac125 100644 --- a/luxonis_train/loaders/luxonis_loader_torch.py +++ b/luxonis_train/loaders/luxonis_loader_torch.py @@ -157,7 +157,9 @@ def input_shapes(self) -> dict[str, Size]: return {self.image_source: img.shape} @override - def get(self, idx: int) -> tuple[dict[str, Tensor] | Tensor, Labels]: + def __getitem__( + self, idx: int + ) -> tuple[dict[str, Tensor] | Tensor, Labels]: img, labels = self.loader[idx] if isinstance(img, np.ndarray): img = {self.image_source: img} diff --git a/luxonis_train/loaders/luxonis_perlin_loader_torch.py b/luxonis_train/loaders/luxonis_perlin_loader_torch.py index 27c43b88..867f6def 100644 --- a/luxonis_train/loaders/luxonis_perlin_loader_torch.py +++ b/luxonis_train/loaders/luxonis_perlin_loader_torch.py @@ -78,7 +78,7 @@ def __init__( self.augmentations = self.loader.augmentations @override - def get(self, idx: int) -> tuple[Tensor, Labels]: + def __getitem__(self, idx: int) -> tuple[Tensor, Labels]: with _freeze_seed(): img, labels = self.loader[idx] if isinstance(img, dict): diff --git a/luxonis_train/utils/__init__.py b/luxonis_train/utils/__init__.py index 19967a57..be4b5f3f 100644 --- a/luxonis_train/utils/__init__.py +++ b/luxonis_train/utils/__init__.py @@ -57,27 +57,18 @@ "dist2bbox", "get_attribute_check_none", "get_batch_instances", - "get_batch_instances", "get_center_keypoints", "get_sigmas", "get_with_default", "infer_upscale_factor", "insert_class", "instances_from_batch", - "instances_from_batch", - "instances_from_batch", - "instances_from_batch", "keypoints_to_bboxes", "make_divisible", - "make_divisible", "non_max_suppression", - "non_max_suppression", - "safe_download", "safe_download", "seg_output_to_bool", "setup_logging", - "setup_logging", - "to_shape_packet", "to_shape_packet", "transform_boxes", "transform_keypoints", From cb07292b9b74039a0b30dac153be5d69c268340b Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sun, 5 Apr 2026 15:17:25 +0200 Subject: [PATCH 13/74] fix keypoint loss --- .../attached_modules/losses/efficient_keypoint_bbox_loss.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index d535a4af..3e27c684 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F from loguru import logger -from torch import Tensor +from torch import Tensor, nn from luxonis_train.attached_modules.losses import AdaptiveDetectionLoss from luxonis_train.nodes import EfficientKeypointBBoxHead @@ -17,8 +17,6 @@ from luxonis_train.utils.boundingbox import IoUType from luxonis_train.utils.keypoints import insert_class -from .bce_with_logits import BCEWithLogitsLoss - class EfficientKeypointBBoxLoss(AdaptiveDetectionLoss): node: EfficientKeypointBBoxHead @@ -74,7 +72,7 @@ def __init__( **kwargs, ) - self.b_cross_entropy = BCEWithLogitsLoss( + self.b_cross_entropy = nn.BCEWithLogitsLoss( pos_weight=torch.tensor([viz_pw]) ) self.sigmas = get_sigmas( From 9dd61d8d1cc2fa1c1712470b1052bb5cb1291602 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sun, 5 Apr 2026 15:17:37 +0200 Subject: [PATCH 14/74] optimizer --- luxonis_train/core/core.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index c48fcfe2..8527363a 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -29,6 +29,8 @@ from luxonis_ml.utils import Environ, LuxonisFileSystem from rich.progress import track from torch import nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler, StepLR from typeguard import typechecked from luxonis_train.callbacks import ( @@ -46,7 +48,7 @@ LuxonisLoaderTorch, ) from luxonis_train.loaders.base_loader import LuxonisLoaderTorchOutput -from luxonis_train.registry import LOADERS +from luxonis_train.registry import LOADERS, OPTIMIZERS, from_registry from luxonis_train.typing import View from luxonis_train.utils import ( DatasetMetadata, @@ -1180,11 +1182,13 @@ def quantize( default_param_bw: int = 8, config_file: str | None = None, default_data_type: QuantizationDataType = QuantizationDataType.int, - adaround: bool = True, + adaround: bool = False, adaround_iterations: int | None = None, fold_batch_norms: bool = False, cross_layer_equalization: bool = False, batch_norm_reestimation: bool = False, + optimizer: Optimizer | None = None, + scheduler: LRScheduler | None = None, ) -> None: """Runs post-training quantization and quantization-aware training using AIMET. @@ -1223,6 +1227,7 @@ def pass_calibration_data(model: nn.Module) -> None: ): model.forward(imgs) + cfg = self.cfg.exporter.aimet model = self.lightning_module model.reparametrize() loader = self.pytorch_loaders["val"] @@ -1310,25 +1315,27 @@ def pass_calibration_data(model: nn.Module) -> None: model.cuda() model.automatic_optimization = False + if optimizer is None: + opt_cfg = cfg.optimizer or self.cfg.trainer.optimizer + optimizer = from_registry( + OPTIMIZERS, + opt_cfg.name, + params=model.parameters(), + **opt_cfg.params, + ) + + if scheduler is None: + scheduler = StepLR(optimizer, step_size=5, gamma=0.1) + for e in track( range(epochs), description="Running Quantization-Aware Training..." ): for imgs, labels in self.pytorch_loaders["train"]: + optimizer.zero_grad() loss = model.training_step((imgs, labels)) - optimizers = model.optimizers() - schedulers = model.lr_schedulers() - if not isinstance(optimizers, list): - optimizers = [optimizers] - if not isinstance(schedulers, list): - schedulers = [schedulers] - for optimizer in optimizers: - optimizer.zero_grad() model.manual_backward(loss) - for optimizer in optimizers: - optimizer.step() - for scheduler in schedulers: - if scheduler is not None: - scheduler.step(e) + optimizer.step() + scheduler.step(e) if batch_norm_reestimation: logger.info("Reestimating batch norm statistics") From d45abc411539d2fb7f941519af647cf4adb8916a Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Sun, 5 Apr 2026 15:17:42 +0200 Subject: [PATCH 15/74] config options --- luxonis_train/config/config.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index fa784e02..c0530e5e 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -729,14 +729,28 @@ def _validate_quantization_mode(value: str) -> str: return value +class AdaroundConfig(BaseModelExtraForbid): + active: bool = False + default_num_iterations: PositiveInt | None = None + default_reg_param: float = 0.01 + default_beta_range: tuple[int, int] = (20, 2) + default_warm_start: float = 0.2 + + class AIMETConfig(BaseModelExtraForbid): active: bool = False - epochs: PositiveInt = 4 + epochs: PositiveInt = 20 default_output_bw: Literal[4, 8, 16] = 8 default_param_bw: Literal[4, 8, 16] = 8 default_data_type: QuantizationDataType = QuantizationDataType.int quant_scheme: QuantScheme = QuantScheme.min_max config: Params | None = None + fold_batch_norms: bool = False + cross_layer_equalization: bool = False + batch_norm_reestimation: bool = False + adaround: AdaroundConfig = Field(default_factory=AdaroundConfig) + optimizer: ConfigItem | None = None + scheduler: ConfigItem | None = None @field_validator("config", mode="before") @classmethod From 22c98a119454b047822ae83097c89622a20e7ab2 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 07:33:53 +0200 Subject: [PATCH 16/74] loading config values --- luxonis_train/core/core.py | 64 +++++++++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 8527363a..cddbbe48 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1,3 +1,4 @@ +import json import math import threading from collections.abc import Mapping @@ -1176,17 +1177,20 @@ def convert( def quantize( self, weights: PathType | None = None, - epochs: int = 20, - quant_scheme: str | QuantScheme = QuantScheme.min_max, - default_output_bw: int = 8, - default_param_bw: int = 8, + epochs: int | None = None, + quant_scheme: str | QuantScheme | None = None, + default_output_bw: int | None = None, + default_param_bw: int | None = None, config_file: str | None = None, - default_data_type: QuantizationDataType = QuantizationDataType.int, - adaround: bool = False, + default_data_type: QuantizationDataType | None = None, + adaround: bool | None = None, adaround_iterations: int | None = None, - fold_batch_norms: bool = False, - cross_layer_equalization: bool = False, - batch_norm_reestimation: bool = False, + adaround_reg_param: float | None = None, + adaround_beta_range: tuple[int, int] | None = None, + adaround_warm_start: float | None = None, + fold_batch_norms: bool | None = None, + cross_layer_equalization: bool | None = None, + batch_norm_reestimation: bool | None = None, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, ) -> None: @@ -1227,7 +1231,47 @@ def pass_calibration_data(model: nn.Module) -> None: ): model.forward(imgs) + save_dir = self.run_save_dir / "aimet" + save_dir.mkdir(parents=True, exist_ok=True) + cfg = self.cfg.exporter.aimet + epochs = epochs or cfg.epochs + quant_scheme = quant_scheme or cfg.quant_scheme + if isinstance(quant_scheme, str): + quant_scheme = QuantScheme.from_str(quant_scheme) + default_output_bw = default_output_bw or cfg.default_output_bw + default_param_bw = default_param_bw or cfg.default_param_bw + default_data_type = default_data_type or cfg.default_data_type + adaround = adaround if adaround is not None else cfg.adaround.active + adaround_iterations = ( + adaround_iterations or cfg.adaround.default_num_iterations + ) + adaround_reg_param = ( + adaround_reg_param or cfg.adaround.default_reg_param + ) + adaround_beta_range = ( + adaround_beta_range or cfg.adaround.default_beta_range + ) + adaround_warm_start = ( + adaround_warm_start or cfg.adaround.default_warm_start + ) + aimet_config_file = config_file or cfg.config + if isinstance(aimet_config_file, dict): + with open(save_dir / "aimet_config.json", "w") as f: + json.dump(aimet_config_file, f, indent=4) + aimet_config_file = str(save_dir / "aimet_config.json") + + fold_batch_norms = ( + fold_batch_norms + if fold_batch_norms is not None + else cfg.fold_batch_norms + ) + cross_layer_equalization = ( + cross_layer_equalization + if cross_layer_equalization is not None + else cfg.cross_layer_equalization + ) + model = self.lightning_module model.reparametrize() loader = self.pytorch_loaders["val"] @@ -1235,8 +1279,6 @@ def pass_calibration_data(model: nn.Module) -> None: if weights is not None: model.load_checkpoint(weights) - save_dir = self.run_save_dir / "aimet" - save_dir.mkdir(parents=True, exist_ok=True) pre_quant_test = self.test(view="val") inputs = { From eafbcf9b04e0d9c65c800352b10c0f0959611b58 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 09:27:50 +0200 Subject: [PATCH 17/74] cleaned up --- luxonis_train/callbacks/aimet_callback.py | 28 +-- luxonis_train/core/core.py | 243 +++++++--------------- luxonis_train/core/utils/aimet_utils.py | 155 ++++++++++++++ luxonis_train/lightning/utils.py | 15 +- 4 files changed, 233 insertions(+), 208 deletions(-) create mode 100644 luxonis_train/core/utils/aimet_utils.py diff --git a/luxonis_train/callbacks/aimet_callback.py b/luxonis_train/callbacks/aimet_callback.py index aef7499d..99d59bd1 100644 --- a/luxonis_train/callbacks/aimet_callback.py +++ b/luxonis_train/callbacks/aimet_callback.py @@ -1,5 +1,4 @@ import lightning.pytorch as pl -from aimet_torch.common.defs import QuantizationDataType, QuantScheme import luxonis_train as lxt from luxonis_train.callbacks.needs_checkpoint import NeedsCheckpoint @@ -8,33 +7,10 @@ @CALLBACKS.register() class AIMETCallback(NeedsCheckpoint): - def __init__( - self, - epochs: int = 4, - quant_scheme: str | QuantScheme = QuantScheme.min_max, - default_output_bw: int = 8, - default_param_bw: int = 8, - config_file: str | None = None, - default_data_type: QuantizationDataType = QuantizationDataType.int, - **kwargs, - ): + def __init__(self, **kwargs): super().__init__(**kwargs) - self.epochs = epochs - self.quant_scheme = quant_scheme - self.default_output_bw = default_output_bw - self.default_param_bw = default_param_bw - self.config_file = config_file - self.default_data_type = default_data_type def on_train_end( self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule" ) -> None: - pl_module.core.quantize( - self.get_checkpoint(pl_module), - self.epochs, - self.quant_scheme, - self.default_output_bw, - self.default_param_bw, - self.config_file, - self.default_data_type, - ) + pl_module.core.quantize(self.get_checkpoint(pl_module)) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index cddbbe48..add9e613 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1,11 +1,10 @@ import json -import math import threading from collections.abc import Mapping from copy import deepcopy from pathlib import Path from threading import ExceptHookArgs, Thread -from typing import Any, Literal, cast, overload +from typing import Literal, overload import lightning.pytorch as pl import lightning_utilities.core.rank_zero as rank_zero_module @@ -13,25 +12,17 @@ import torch import torch.utils.data as torch_data import yaml -from aimet_torch import QuantizationSimModel -from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters -from aimet_torch.bn_reestimation import reestimate_bn_stats from aimet_torch.common.defs import QuantizationDataType, QuantScheme -from aimet_torch.cross_layer_equalization import equalize_model -from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms -from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities import rank_zero_only from loguru import logger from luxonis_ml.data import LuxonisDataset from luxonis_ml.nn_archive import ArchiveGenerator from luxonis_ml.nn_archive.config import CONFIG_VERSION -from luxonis_ml.typing import Params, PathType +from luxonis_ml.typing import ConfigItem, Params, PathType from luxonis_ml.utils import Environ, LuxonisFileSystem -from rich.progress import track -from torch import nn from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler, StepLR +from torch.optim.lr_scheduler import LRScheduler from typeguard import typechecked from luxonis_train.callbacks import ( @@ -49,7 +40,12 @@ LuxonisLoaderTorch, ) from luxonis_train.loaders.base_loader import LuxonisLoaderTorchOutput -from luxonis_train.registry import LOADERS, OPTIMIZERS, from_registry +from luxonis_train.registry import ( + LOADERS, + OPTIMIZERS, + SCHEDULERS, + from_registry, +) from luxonis_train.typing import View from luxonis_train.utils import ( DatasetMetadata, @@ -57,6 +53,10 @@ setup_logging, ) +from .utils.aimet_utils import ( + post_training_quantization, + quantization_aware_training, +) from .utils.annotate_utils import annotate_from_directory from .utils.archive_utils import ( get_head_configs, @@ -1194,73 +1194,19 @@ def quantize( optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, ) -> None: - """Runs post-training quantization and quantization-aware - training using AIMET. - - @type weights: PathType | None - @param weights: Path to the checkpoint from which to load weights. - @type epochs: int - @param epochs: Number of epochs to run quantization-aware training for. - @type quant_scheme: str | QuantScheme - @param quant_scheme: Quantization scheme to use. Can be either a string - or an instance of `aimet_common.defs.QuantScheme`. If a string is - provided, it will be converted to the corresponding `QuantScheme` - instance. Defaults to `QuantScheme.min_max`. - @type default_output_bw: int - @param default_output_bw: Default bitwidth to use for quantizing outputs. - @type default_param_bw: int - @param default_param_bw: Default bitwidth to use for quantizing parameters. - @type config_file: str | None - @param config_file: Path to the AIMET config file specifying quantization settings for specific layers. If not provided, default quantization settings will be applied to all layers. - @type default_data_type: QuantizationDataType - @param default_data_type: Data type to use for quantization (e.g. integer or float). Defaults to `QuantizationDataType.int`. - - @warning: Only in-place quantization is currently supported, - meaning that the original model will be modified and - exported ONNX model will be the quantized version of - the original model. Make sure to keep a backup of the - original model weights if you want to continue training - or exporting the original model after quantization. - """ - - def pass_calibration_data(model: nn.Module) -> None: - for imgs, _ in track( - self.pytorch_loaders["val"], - description="Computing quantization encodings...", - total=len(self.pytorch_loaders["val"]), - ): - model.forward(imgs) save_dir = self.run_save_dir / "aimet" save_dir.mkdir(parents=True, exist_ok=True) cfg = self.cfg.exporter.aimet - epochs = epochs or cfg.epochs - quant_scheme = quant_scheme or cfg.quant_scheme - if isinstance(quant_scheme, str): - quant_scheme = QuantScheme.from_str(quant_scheme) - default_output_bw = default_output_bw or cfg.default_output_bw - default_param_bw = default_param_bw or cfg.default_param_bw - default_data_type = default_data_type or cfg.default_data_type - adaround = adaround if adaround is not None else cfg.adaround.active - adaround_iterations = ( - adaround_iterations or cfg.adaround.default_num_iterations - ) - adaround_reg_param = ( - adaround_reg_param or cfg.adaround.default_reg_param - ) - adaround_beta_range = ( - adaround_beta_range or cfg.adaround.default_beta_range - ) - adaround_warm_start = ( - adaround_warm_start or cfg.adaround.default_warm_start - ) + aimet_config_file = config_file or cfg.config if isinstance(aimet_config_file, dict): with open(save_dir / "aimet_config.json", "w") as f: json.dump(aimet_config_file, f, indent=4) aimet_config_file = str(save_dir / "aimet_config.json") + adaround = adaround if adaround is not None else cfg.adaround.active fold_batch_norms = ( fold_batch_norms if fold_batch_norms is not None @@ -1271,144 +1217,103 @@ def pass_calibration_data(model: nn.Module) -> None: if cross_layer_equalization is not None else cfg.cross_layer_equalization ) + batch_norm_reestimation = ( + batch_norm_reestimation + if batch_norm_reestimation is not None + else cfg.batch_norm_reestimation + ) - model = self.lightning_module + model = deepcopy(self.lightning_module) model.reparametrize() - loader = self.pytorch_loaders["val"] + pre_quant_test = self.pl_trainer.test( + model, self.pytorch_loaders["val"] + )[0] if weights is not None: model.load_checkpoint(weights) - pre_quant_test = self.test(view="val") - - inputs = { + dummy_inputs = { input_name: torch.randn([1, *shape]).to(model.device) for shapes in model.nodes.loader_input_shapes.values() for input_name, shape in shapes.items() } - if len(inputs) > 1: + if len(dummy_inputs) > 1: raise NotImplementedError( "Quantization is not yet supported for models " "with multiple inputs." ) - input_names = list(inputs.keys()) - output_names = model._get_output_onnx_names(deepcopy(inputs)) - inputs = next(iter(inputs.values())) - - if CUDAAccelerator.is_available(): - inputs = inputs.cuda() - model.cuda() - - if fold_batch_norms and not batch_norm_reestimation: - logger.info("Folding batch norms into preceding layers") - fold_all_batch_norms( - model, input_shapes=inputs.shape, dummy_input=inputs - ) - if cross_layer_equalization: - logger.info("Applying cross-layer equalization") - equalize_model( - model, input_shapes=inputs.shape, dummy_input=inputs - ) - - if adaround: - ada_params = AdaroundParameters( - data_loader=loader, - num_batches=min( - len(loader), - math.ceil(2000 / self.cfg.trainer.batch_size), - ), - default_num_iterations=adaround_iterations, # type: ignore - ) - model = cast( - LuxonisLightningModule, - Adaround.apply_adaround( - model, - inputs, - ada_params, - path=str(save_dir), - filename_prefix="adaround", - ), - ) - - sim = QuantizationSimModel( - model=model, - dummy_input=inputs, - quant_scheme=quant_scheme, - default_output_bw=default_output_bw, - default_param_bw=default_param_bw, - config_file=config_file, - default_data_type=default_data_type, - in_place=True, + input_names = list(dummy_inputs.keys()) + output_names = model._get_output_onnx_names(deepcopy(dummy_inputs)) + dummy_inputs = next(iter(dummy_inputs.values())) + + sim = post_training_quantization( + model, + dummy_inputs, + self.pytorch_loaders["val"], + save_dir, + quant_scheme or cfg.quant_scheme, + default_output_bw or cfg.default_output_bw, + default_param_bw or cfg.default_param_bw, + default_data_type or cfg.default_data_type, + aimet_config_file, + adaround, + adaround_iterations or cfg.adaround.default_num_iterations, + adaround_reg_param or cfg.adaround.default_reg_param, + adaround_beta_range or cfg.adaround.default_beta_range, + adaround_warm_start or cfg.adaround.default_warm_start, + fold_batch_norms, + cross_layer_equalization, + batch_norm_reestimation, ) - model = cast(LuxonisLightningModule, sim.model) - - if adaround: - sim.set_and_freeze_param_encodings( - str(save_dir / "adaround.encodings") - ) - - sim.compute_encodings(pass_calibration_data) - - ptq_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] - model.train() - if CUDAAccelerator.is_available(): - model.cuda() - model.automatic_optimization = False + ptq_test = self.pl_trainer.test( + sim.model, # type: ignore + self.pytorch_loaders["val"], + )[0] if optimizer is None: opt_cfg = cfg.optimizer or self.cfg.trainer.optimizer optimizer = from_registry( OPTIMIZERS, opt_cfg.name, - params=model.parameters(), + params=sim.model.parameters(), **opt_cfg.params, ) - if scheduler is None: - scheduler = StepLR(optimizer, step_size=5, gamma=0.1) - - for e in track( - range(epochs), description="Running Quantization-Aware Training..." - ): - for imgs, labels in self.pytorch_loaders["train"]: - optimizer.zero_grad() - loss = model.training_step((imgs, labels)) - model.manual_backward(loss) - optimizer.step() - scheduler.step(e) - - if batch_norm_reestimation: - logger.info("Reestimating batch norm statistics") - - def _forward_pass( - model: nn.Module, inputs: LuxonisLoaderTorchOutput - ) -> Any: - return model(inputs[0]) - - reestimate_bn_stats( - model, self.pytorch_loaders["train"], forward_fn=_forward_pass + sch_cfg = cfg.scheduler or ConfigItem( + name="StepLR", params={"step_size": 5, "gamma": 0.1} + ) + scheduler = from_registry( + SCHEDULERS, + sch_cfg.name, + optimizer=optimizer, + **sch_cfg.params, ) - if fold_batch_norms: - logger.info("Folding batch norms into preceding layers") - fold_all_batch_norms( - model, input_shapes=inputs.shape, dummy_input=inputs - ) + model = quantization_aware_training( + sim, + dummy_inputs, + self.pytorch_loaders["train"], + optimizer, + scheduler, + epochs or cfg.epochs, + fold_batch_norms, + batch_norm_reestimation, + ) - model.automatic_optimization = True - model.eval() qat_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] + model.set_export_mode(mode=True) sim.onnx.export( - inputs, + dummy_inputs, (save_dir / self.cfg.model.name).with_suffix(".onnx"), input_names=input_names, output_names=output_names, ) model.set_export_mode(mode=False) + table = [] for key, value in pre_quant_test.items(): log_key = key.replace("test/metric/", "").replace("test/loss/", "") diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py new file mode 100644 index 00000000..4cfbd027 --- /dev/null +++ b/luxonis_train/core/utils/aimet_utils.py @@ -0,0 +1,155 @@ +import math +from pathlib import Path +from typing import Any, cast + +from aimet_torch import QuantizationSimModel +from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters +from aimet_torch.bn_reestimation import reestimate_bn_stats +from aimet_torch.common.defs import QuantizationDataType, QuantScheme +from aimet_torch.cross_layer_equalization import equalize_model +from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms +from lightning.pytorch.accelerators import CUDAAccelerator +from loguru import logger +from rich.progress import track +from torch import Tensor, nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.data import DataLoader + +from luxonis_train.lightning import LuxonisLightningModule +from luxonis_train.loaders.base_loader import LuxonisLoaderTorchOutput + + +def post_training_quantization( + model: LuxonisLightningModule, + dummy_inputs: Tensor, + val_loader: DataLoader, + save_dir: Path, + quant_scheme: str | QuantScheme = QuantScheme.min_max, + default_output_bw: int = 8, + default_param_bw: int = 8, + default_data_type: QuantizationDataType = QuantizationDataType.int, + config_file: str | None = None, + adaround: bool = False, + adaround_iterations: int | None = None, + adaround_reg_param: float = 0.01, + adaround_beta_range: tuple[int, int] = (20, 2), + adaround_warm_start: float = 0.2, + fold_batch_norms: bool = False, + cross_layer_equalization: bool = False, + batch_norm_reestimation: bool = False, +) -> QuantizationSimModel: + def pass_calibration_data(model: nn.Module) -> None: + for imgs, _ in track( + val_loader, + description="Computing quantization encodings...", + total=len(val_loader), + ): + model.forward(imgs) + + if CUDAAccelerator.is_available(): + dummy_inputs = dummy_inputs.cuda() + model.cuda() + + if fold_batch_norms and not batch_norm_reestimation: + logger.info("Folding batch norms into preceding layers") + fold_all_batch_norms( + model, input_shapes=dummy_inputs.shape, dummy_input=dummy_inputs + ) + if cross_layer_equalization: + logger.info("Applying cross-layer equalization") + equalize_model( + model, input_shapes=dummy_inputs.shape, dummy_input=dummy_inputs + ) + + if adaround: + ada_params = AdaroundParameters( + data_loader=val_loader, + num_batches=min( + len(val_loader), + math.ceil(2000 / val_loader.batch_size), # type: ignore + ), + default_num_iterations=adaround_iterations, # type: ignore + default_reg_param=adaround_reg_param, + default_beta_range=adaround_beta_range, + default_warm_start=adaround_warm_start, + ) + model = cast( + LuxonisLightningModule, + Adaround.apply_adaround( + model, + dummy_inputs, + ada_params, + path=str(save_dir), + filename_prefix="adaround", + ), + ) + + sim = QuantizationSimModel( + model=model, + dummy_input=dummy_inputs, + quant_scheme=quant_scheme, + default_output_bw=default_output_bw, + default_param_bw=default_param_bw, + config_file=config_file, + default_data_type=default_data_type, + in_place=True, + ) + + if adaround: + sim.set_and_freeze_param_encodings( + str(save_dir / "adaround.encodings") + ) + + sim.compute_encodings(pass_calibration_data) + return sim + + +def quantization_aware_training( + sim: QuantizationSimModel, + dummy_inputs: Tensor, + train_loader: DataLoader, + optimizer: Optimizer, + scheduler: LRScheduler, + epochs: int, + fold_batch_norms: bool = False, + batch_norm_reestimation: bool = False, +) -> LuxonisLightningModule: + + model = cast(LuxonisLightningModule, sim.model) + + model.train() + if CUDAAccelerator.is_available(): + model.cuda() + model.automatic_optimization = False + + for e in track( + range(epochs), description="Running Quantization-Aware Training..." + ): + for imgs, labels in train_loader: + optimizer.zero_grad() + loss = model.training_step((imgs, labels)) + model.manual_backward(loss) + optimizer.step() + scheduler.step(e) + + if batch_norm_reestimation: + logger.info("Reestimating batch norm statistics") + + def _forward_pass( + model: nn.Module, inputs: LuxonisLoaderTorchOutput + ) -> Any: + return model(inputs[0]) + + reestimate_bn_stats(model, train_loader, forward_fn=_forward_pass) + + if fold_batch_norms: + logger.info("Folding batch norms into preceding layers") + fold_all_batch_norms( + model, + input_shapes=dummy_inputs.shape, + dummy_input=dummy_inputs, + ) + + model.automatic_optimization = True + return model diff --git a/luxonis_train/lightning/utils.py b/luxonis_train/lightning/utils.py index 30c8d6ce..d94290fc 100644 --- a/luxonis_train/lightning/utils.py +++ b/luxonis_train/lightning/utils.py @@ -23,6 +23,7 @@ BaseAttachedModule, ) from luxonis_train.callbacks import LuxonisModelSummary, TrainingManager +from luxonis_train.callbacks.aimet_callback import AIMETCallback from luxonis_train.config import AttachedModuleConfig, Config from luxonis_train.config.config import NodeConfig from luxonis_train.nodes import BaseNode @@ -520,19 +521,7 @@ def build_callbacks( "parameter in the config will be ignored." ) if cfg.exporter.aimet.active: - aimet_cfg = cfg.exporter.aimet - callbacks.append( - from_registry( - CALLBACKS, - "AIMETCallback", - epochs=aimet_cfg.epochs, - quant_scheme=aimet_cfg.quant_scheme, - default_output_bw=aimet_cfg.default_output_bw, - default_param_bw=aimet_cfg.default_param_bw, - default_data_type=aimet_cfg.default_data_type, - config_file=aimet_cfg.config, - ) - ) + callbacks.append(AIMETCallback()) return callbacks From 59936ec4a3c75f07dfacf89bb5f78e62f6c2f6fd Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 09:43:41 +0200 Subject: [PATCH 18/74] docs --- configs/README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/configs/README.md b/configs/README.md index de52b3c7..9919b057 100644 --- a/configs/README.md +++ b/configs/README.md @@ -505,6 +505,7 @@ Here you can define configuration for exporting. | `onnx` | `dict` | `{}` | Options specific for ONNX export. See [ONNX](#onnx) section for details | | `hubai` | `dict` | `{}` | Options for HubAI SDK conversion. See [HubAI](#hubai) section for details | | `blobconverter` | `dict` | `{}` | Options for converting to BLOB format (deprecated). See [Blob](#blob-deprecated) section | +| `aimet` | | `dict` | `{}` | ### `ONNX` @@ -566,6 +567,40 @@ exporter: shaves: 8 ``` +### `AIMET` + +The [AIMET](https://quic.github.io/aimet-pages/releases/latest/index.html) (AI Model Efficiency Toolkit) provides quantization and model export tools. + +| Key | Type | Default value | Description | +| -------------------------- | ------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `active` | `bool` | `False` | Whether to use AIMET for quantization and export | +| `epochs` | `int` | `20` | Number of epochs to use for quantization-aware training | +| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights | +| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters | +| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values | +| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use | +| `config` | `dict \| str` | `{}` | Additional configuration for AIMET. Can be a dictionary or a path to a JSON config file. Refer to the [AIMET documentation](https://quic.github.io/aimet-pages/releases/latest/docs/QuantizationSim.html#quantization-sim-config) for details on the available options. | +| `fold_batch_norms` | `bool` | `False` | Whether to fold batch normalization layers before quantization | +| `cross_layer_equalization` | `bool` | `False` | Whether to perform cross-layer equalization before quantization | +| `batch_norm_reestimation` | `bool` | `False` | Whether to perform batch norm re-estimation after quantization | +| `optimizer` | `dict` | `{}` | Optimizer configuration for quantization-aware training. See [Optimizer](#optimizer) section for details and examples. If not set, the `trainer` optimizer is used. | +| `scheduler` | `dict` | `{}` | Scheduler configuration for quantization-aware training. See [Scheduler](#scheduler) section for details and examples. If not set, the `trainer` scheduler is used. | +| `adaround` | `dict` | `{}` | Configuration for Adaround weight rounding. See [Adaround](#adaround) for more details. | + +#### Adaround + +Adaptive rounding (AdaRound) is a rounding mechanism for model weights designed to adapt to the data to improve the accuracy of the quantized model. + +By default, AIMET uses nearest rounding for quantization, in which weight values are quantized to the nearest integer value. AdaRound, however, uses training data to determine how to round quantized weights. This technique often improves the accuracy of the quantized model. + +| Key | Type | Default value | Description | +| ------------------------ | ----------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `active` | `bool` | `False` | Whether to use AdaRound for weight rounding during quantization | +| `default_num_iterations` | `int \| None` | `None` | Number of iterations for the AdaRound optimization. The default value is 10K for models with 8- or higher bit weights, and 15K for models with lower than 8 bit weights. | +| `default_reg_param` | `float` | `0.01` | Regularization parameter, trading off between rounding loss vs reconstruction loss. | +| `default_beta_range` | `tuple[int, int]` | `(20, 2)` | Start and stop beta parameter for annealing of rounding loss (start_beta, end_beta). | +| `default_warm_start` | `float` | `0.2` | The warm up period, during which rounding loss has zero effect. | + ## Tuner Here you can specify options for tuning. From ea0755069106517a99d7fece32db393cef61463b Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 09:46:33 +0200 Subject: [PATCH 19/74] test --- luxonis_train/callbacks/gradcam_visualizer.py | 4 +++- tests/integration/test_callbacks.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/luxonis_train/callbacks/gradcam_visualizer.py b/luxonis_train/callbacks/gradcam_visualizer.py index ed26b08e..6c40e5a8 100644 --- a/luxonis_train/callbacks/gradcam_visualizer.py +++ b/luxonis_train/callbacks/gradcam_visualizer.py @@ -142,7 +142,9 @@ def on_validation_batch_end( @param batch_idx: The index of the batch. """ if batch_idx < self.log_n_batches: - images = batch[0][pl_module.image_source] + images = batch[0] + if isinstance(images, dict): + images = images[pl_module.image_source] self.visualize_gradients(trainer, pl_module, images, batch_idx) def visualize_gradients( diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py index afea91c0..0693d771 100644 --- a/tests/integration/test_callbacks.py +++ b/tests/integration/test_callbacks.py @@ -56,6 +56,13 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path): "exporter.scale_values": [0.5, 0.5, 0.5], "exporter.mean_values": [0.5, 0.5, 0.5], "exporter.blobconverter.active": True, + "exporter.aimet": { + "active": True, + "fold_batch_norms": True, + "batch_norm_reestimation": True, + "cross_layer_equalization": True, + }, + "exporter.aimet.adaround.active": True, "loader.params.dataset_name": coco_dataset.identifier, } model = LuxonisModel(config_file, opts, debug_mode=True) From ea3f8cd9420072cf87c9bfd8aa00645d7b0860a6 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 10:10:13 +0200 Subject: [PATCH 20/74] fix tests --- .../attached_modules/visualizers/base_visualizer.py | 5 +++++ .../attached_modules/visualizers/embeddings_visualizer.py | 4 +--- .../attached_modules/visualizers/segmentation_visualizer.py | 2 -- luxonis_train/callbacks/gradcam_visualizer.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/luxonis_train/attached_modules/visualizers/base_visualizer.py b/luxonis_train/attached_modules/visualizers/base_visualizer.py index ed7f9cc8..618ed221 100644 --- a/luxonis_train/attached_modules/visualizers/base_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/base_visualizer.py @@ -3,6 +3,7 @@ from inspect import Parameter import torch.nn.functional as F +from luxonis_ml.data.utils import ColorMap from torch import Tensor from typing_extensions import TypeVarTuple, Unpack @@ -34,6 +35,10 @@ def scale_canvas(canvas: Tensor, scale: float = 1.0) -> Tensor: align_corners=False, ) + @cached_property + def colormap(self) -> ColorMap: + return ColorMap() + @abstractmethod def forward( self, diff --git a/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py index 772c632d..d9f4cd7b 100644 --- a/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py @@ -3,7 +3,6 @@ import numpy as np import seaborn as sns from loguru import logger -from luxonis_ml.data.utils import ColorMap from matplotlib import pyplot as plt from sklearn.decomposition import PCA from torch import Tensor @@ -25,11 +24,10 @@ def __init__(self, z_score_threshold: float = 3, **kwargs): outliers. """ super().__init__(**kwargs) - self.colors = ColorMap() self.z_score_threshold = z_score_threshold def _get_color(self, label: int) -> tuple[float, float, float]: - r, g, b = self.colors[label] + r, g, b = self.colormap[label] return r / 255, g / 255, b / 255 def forward( diff --git a/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py b/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py index 6a457b22..92eda5d1 100644 --- a/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/segmentation_visualizer.py @@ -2,7 +2,6 @@ import torch from loguru import logger -from luxonis_ml.data.utils.visualizations import ColorMap from torch import Tensor from typing_extensions import override @@ -47,7 +46,6 @@ def __init__( self.background_class = background_class self.background_color = background_color self.alpha = alpha - self.colormap = ColorMap() self._warn_colors = True diff --git a/luxonis_train/callbacks/gradcam_visualizer.py b/luxonis_train/callbacks/gradcam_visualizer.py index 6c40e5a8..95ec67f0 100644 --- a/luxonis_train/callbacks/gradcam_visualizer.py +++ b/luxonis_train/callbacks/gradcam_visualizer.py @@ -47,7 +47,7 @@ def forward(self, inputs: Tensor, *args, **kwargs) -> Tensor: @return: The processed output based on the task type. """ input_dict = {"image": inputs} - output = self.pl_module(input_dict, *args, **kwargs) + output = self.pl_module._run_forward(input_dict, *args, **kwargs) if len(output.outputs) > 1: logger.warning( "Model has multiple heads. Using the first head for Grad-CAM." From 2716b1ea0c89fab30a0cb8aeec94148e3897deff Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 16:42:08 +0200 Subject: [PATCH 21/74] fixes --- .../losses/reconstruction_segmentation_loss.py | 6 ++++++ .../recognition_confusion_matrix.py | 7 +++++++ .../attached_modules/metrics/torchmetrics.py | 4 ++++ luxonis_train/core/core.py | 14 ++++++-------- luxonis_train/core/utils/aimet_utils.py | 12 ++++++++++-- luxonis_train/loaders/debug_loader.py | 4 +++- luxonis_train/nodes/blocks/__init__.py | 8 ++++++++ luxonis_train/nodes/blocks/blocks.py | 5 +---- luxonis_train/utils/dataset_metadata.py | 6 +----- 9 files changed, 46 insertions(+), 20 deletions(-) diff --git a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py index f71cd762..77622d01 100644 --- a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py @@ -1,3 +1,4 @@ +from contextlib import suppress from math import exp from typing import Literal @@ -84,6 +85,11 @@ def __init__( self.channel = 1 self.window = create_window(window_size) + with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin + + QuantizationMixin.ignore(self.__class__) + def forward(self, img1: Tensor, img2: Tensor) -> Tensor: device = img1.device with amp.autocast(device_type=device.type, enabled=False): diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py index 261e2cc4..87f5d6d0 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py @@ -1,3 +1,5 @@ +from contextlib import suppress + from torch import Tensor from torchmetrics.classification import ( BinaryConfusionMatrix, @@ -27,6 +29,11 @@ def __init__(self, **kwargs): else: self.metric = MulticlassConfusionMatrix(num_classes=self.n_classes) + with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin + + QuantizationMixin.ignore(self.metric.__class__) + @override def update(self, predictions: Tensor, targets: Tensor) -> None: if self.n_classes > 1: diff --git a/luxonis_train/attached_modules/metrics/torchmetrics.py b/luxonis_train/attached_modules/metrics/torchmetrics.py index 95efb5d0..1a233314 100644 --- a/luxonis_train/attached_modules/metrics/torchmetrics.py +++ b/luxonis_train/attached_modules/metrics/torchmetrics.py @@ -52,6 +52,10 @@ def __init__(self, **kwargs): kwargs["num_labels"] = n_classes self.metric = self.Metric(**kwargs) + with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin + + QuantizationMixin.ignore(self.metric.__class__) @override def update(self, predictions: Tensor, target: Tensor) -> None: diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index add9e613..7458e280 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -4,7 +4,7 @@ from copy import deepcopy from pathlib import Path from threading import ExceptHookArgs, Thread -from typing import Literal, overload +from typing import Literal, cast, overload import lightning.pytorch as pl import lightning_utilities.core.rank_zero as rank_zero_module @@ -1224,7 +1224,7 @@ def quantize( ) model = deepcopy(self.lightning_module) - model.reparametrize() + model.reparametrize().eval() pre_quant_test = self.pl_trainer.test( model, self.pytorch_loaders["val"] )[0] @@ -1266,11 +1266,10 @@ def quantize( cross_layer_equalization, batch_norm_reestimation, ) + model = cast(LuxonisLightningModule, sim.model) - ptq_test = self.pl_trainer.test( - sim.model, # type: ignore - self.pytorch_loaders["val"], - )[0] + model.eval() + ptq_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] if optimizer is None: opt_cfg = cfg.optimizer or self.cfg.trainer.optimizer @@ -1300,7 +1299,7 @@ def quantize( epochs or cfg.epochs, fold_batch_norms, batch_norm_reestimation, - ) + ).eval() qat_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] @@ -1312,7 +1311,6 @@ def quantize( input_names=input_names, output_names=output_names, ) - model.set_export_mode(mode=False) table = [] for key, value in pre_quant_test.items(): diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py index 4cfbd027..16b9d7f5 100644 --- a/luxonis_train/core/utils/aimet_utils.py +++ b/luxonis_train/core/utils/aimet_utils.py @@ -39,10 +39,14 @@ def post_training_quantization( cross_layer_equalization: bool = False, batch_norm_reestimation: bool = False, ) -> QuantizationSimModel: + def pass_calibration_data(model: nn.Module) -> None: + assert len(val_loader) > 0, ( + "Validation loader must have at least one batch" + ) for imgs, _ in track( val_loader, - description="Computing quantization encodings...", + description="Computing quantization encodings", total=len(val_loader), ): model.forward(imgs) @@ -51,6 +55,8 @@ def pass_calibration_data(model: nn.Module) -> None: dummy_inputs = dummy_inputs.cuda() model.cuda() + model.eval() + if fold_batch_norms and not batch_norm_reestimation: logger.info("Folding batch norms into preceding layers") fold_all_batch_norms( @@ -124,7 +130,9 @@ def quantization_aware_training( model.automatic_optimization = False for e in track( - range(epochs), description="Running Quantization-Aware Training..." + range(epochs), + description="Running Quantization-Aware Training", + total=epochs, ): for imgs, labels in train_loader: optimizer.zero_grad() diff --git a/luxonis_train/loaders/debug_loader.py b/luxonis_train/loaders/debug_loader.py index 96d8cee8..a1a89aef 100644 --- a/luxonis_train/loaders/debug_loader.py +++ b/luxonis_train/loaders/debug_loader.py @@ -71,7 +71,9 @@ def __len__(self) -> int: return self.batch_size * 10 @override - def get(self, idx: int) -> tuple[Tensor | dict[str, Tensor], Labels]: + def __getitem__( + self, idx: int + ) -> tuple[Tensor | dict[str, Tensor], Labels]: img = torch.zeros(self.n_channels, self.height, self.width) label_shapes = self.get_label_shapes(self.labels) labels = { diff --git a/luxonis_train/nodes/blocks/__init__.py b/luxonis_train/nodes/blocks/__init__.py index 9e9eb783..ffc323dd 100644 --- a/luxonis_train/nodes/blocks/__init__.py +++ b/luxonis_train/nodes/blocks/__init__.py @@ -1,3 +1,5 @@ +from contextlib import suppress + from .blocks import ( DFL, AttentionRefinmentBlock, @@ -27,6 +29,12 @@ UpBlock, ) +with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin + + QuantizationMixin.ignore(DropPath) + QuantizationMixin.ignore(UpscaleOnline) + __all__ = [ "DFL", "AttentionRefinmentBlock", diff --git a/luxonis_train/nodes/blocks/blocks.py b/luxonis_train/nodes/blocks/blocks.py index a7722e1f..172104d8 100644 --- a/luxonis_train/nodes/blocks/blocks.py +++ b/luxonis_train/nodes/blocks/blocks.py @@ -489,10 +489,7 @@ def reparametrize(self) -> None: @override def restore(self) -> None: if self.fused_branch is None: - raise RuntimeError( - f"Cannot restore '{self.name}' " - "that has not yet been reparametrized." - ) + return # Not sure if this is necessary for param in self.fused_branch.parameters(): diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py index d016e322..24a888fa 100644 --- a/luxonis_train/utils/dataset_metadata.py +++ b/luxonis_train/utils/dataset_metadata.py @@ -129,11 +129,7 @@ def n_keypoints(self, task_name: str | None = None) -> int: task types. """ if task_name is not None: - if task_name not in self._n_keypoints: - raise ValueError( - f"Task '{task_name}' is not present in the dataset." - ) - return self._n_keypoints[task_name] + return self._n_keypoints.get(task_name, 0) n_keypoints = next(iter(self._n_keypoints.values())) for n in self._n_keypoints.values(): if n != n_keypoints: From 9bccbc8f0afaf23ae71c32bf0710a4035f522b33 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 17:01:01 +0200 Subject: [PATCH 22/74] fixed config --- luxonis_train/config/config.py | 20 +++++++++++++++++++- luxonis_train/core/core.py | 9 +++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index c0530e5e..d9b2d41f 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -750,7 +750,25 @@ class AIMETConfig(BaseModelExtraForbid): batch_norm_reestimation: bool = False adaround: AdaroundConfig = Field(default_factory=AdaroundConfig) optimizer: ConfigItem | None = None - scheduler: ConfigItem | None = None + scheduler: ConfigItem = Field( + default_factory=lambda: ConfigItem( + name="StepLR", params={"step_size": 5, "gamma": 0.1} + ) + ) + + @field_validator("quant_scheme", mode="before") + @classmethod + def validate_quant_scheme(cls, value: ParamValue) -> Any: + if isinstance(value, str): + return QuantScheme.from_str(value) + return value + + @field_validator("default_data_type", mode="before") + @classmethod + def validate_default_data_type(cls, value: ParamValue) -> Any: + if isinstance(value, str): + return QuantizationDataType[value.lower()] + return value @field_validator("config", mode="before") @classmethod diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 7458e280..fe916975 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -19,7 +19,7 @@ from luxonis_ml.data import LuxonisDataset from luxonis_ml.nn_archive import ArchiveGenerator from luxonis_ml.nn_archive.config import CONFIG_VERSION -from luxonis_ml.typing import ConfigItem, Params, PathType +from luxonis_ml.typing import Params, PathType from luxonis_ml.utils import Environ, LuxonisFileSystem from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -1280,14 +1280,11 @@ def quantize( **opt_cfg.params, ) if scheduler is None: - sch_cfg = cfg.scheduler or ConfigItem( - name="StepLR", params={"step_size": 5, "gamma": 0.1} - ) scheduler = from_registry( SCHEDULERS, - sch_cfg.name, + cfg.scheduler.name, optimizer=optimizer, - **sch_cfg.params, + **cfg.scheduler.params, ) model = quantization_aware_training( From d200891f8a05b4c1be276d16516a5ff50cf5c4b3 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 17:04:30 +0200 Subject: [PATCH 23/74] updated api --- luxonis_train/lightning/luxonis_lightning.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index 3f66b519..466f21e8 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -15,7 +15,7 @@ from semver import Version from torch import Size, Tensor from torch.nn.modules.module import _IncompatibleKeys -from typing_extensions import override +from typing_extensions import Self, override import luxonis_train from luxonis_train.attached_modules.visualizers import ( @@ -330,15 +330,17 @@ def _run_forward( outputs=outputs_dict, losses=losses, visualizations=visualizations ) - def set_export_mode(self, *, mode: bool) -> None: + def set_export_mode(self, mode: bool) -> Self: for module in self.modules(): if isinstance(module, BaseNode): module.set_export_mode(mode=mode) + return self - def reparametrize(self) -> None: + def reparametrize(self) -> Self: for module in self.modules(): if isinstance(module, Reparametrizable): module.reparametrize() + return self def export_onnx(self, save_path: PathType, **kwargs) -> Path: """Exports the model to ONNX format. From d34ca9172f4a0d379c2a8668c19330f001215e53 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 17:17:09 +0200 Subject: [PATCH 24/74] cleanup --- luxonis_train/core/core.py | 37 ++++++++++++++++++++++++------------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index fe916975..d1833a22 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -23,6 +23,7 @@ from luxonis_ml.utils import Environ, LuxonisFileSystem from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +from torch.utils.data.dataloader import DataLoader from typeguard import typechecked from luxonis_train.callbacks import ( @@ -315,6 +316,18 @@ def __init__( self.lightning_module.load_checkpoint(weights) self._exported_models: dict[str, Path] = {} + @property + def train_loader(self) -> DataLoader: + return self.pytorch_loaders["train"] + + @property + def val_loader(self) -> DataLoader: + return self.pytorch_loaders["val"] + + @property + def test_loader(self) -> DataLoader: + return self.pytorch_loaders["test"] + def _train(self, resume: PathType | None, *args, **kwargs) -> None: status = "success" try: @@ -381,8 +394,8 @@ def train( self._train( resume_weights, self.lightning_module, - self.pytorch_loaders["train"], - self.pytorch_loaders["val"], + self.train_loader, + self.val_loader, ) logger.info("Training finished") logger.info(f"Checkpoints saved in: {self.run_save_dir}") @@ -399,8 +412,8 @@ def thread_exception_hook(args: ExceptHookArgs) -> None: args=( resume_weights, self.lightning_module, - self.pytorch_loaders["train"], - self.pytorch_loaders["val"], + self.train_loader, + self.val_loader, ), daemon=True, ) @@ -830,8 +843,8 @@ def _objective(trial: optuna.trial.Trial) -> float: try: pl_trainer.fit( lightning_module, - self.pytorch_loaders["train"], - self.pytorch_loaders["val"], + self.train_loader, + self.val_loader, ) pruner_callback.check_pruned() @@ -1225,9 +1238,7 @@ def quantize( model = deepcopy(self.lightning_module) model.reparametrize().eval() - pre_quant_test = self.pl_trainer.test( - model, self.pytorch_loaders["val"] - )[0] + pre_quant_test = self.pl_trainer.test(model, self.val_loader)[0] if weights is not None: model.load_checkpoint(weights) @@ -1250,7 +1261,7 @@ def quantize( sim = post_training_quantization( model, dummy_inputs, - self.pytorch_loaders["val"], + self.val_loader, save_dir, quant_scheme or cfg.quant_scheme, default_output_bw or cfg.default_output_bw, @@ -1269,7 +1280,7 @@ def quantize( model = cast(LuxonisLightningModule, sim.model) model.eval() - ptq_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] + ptq_test = self.pl_trainer.test(model, self.val_loader)[0] if optimizer is None: opt_cfg = cfg.optimizer or self.cfg.trainer.optimizer @@ -1290,7 +1301,7 @@ def quantize( model = quantization_aware_training( sim, dummy_inputs, - self.pytorch_loaders["train"], + self.train_loader, optimizer, scheduler, epochs or cfg.epochs, @@ -1298,7 +1309,7 @@ def quantize( batch_norm_reestimation, ).eval() - qat_test = self.pl_trainer.test(model, self.pytorch_loaders["val"])[0] + qat_test = self.pl_trainer.test(model, self.val_loader)[0] model.set_export_mode(mode=True) From 1140c36af3f7e505f97f30e01dfebc0a8f34a12e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 17:17:15 +0200 Subject: [PATCH 25/74] cli fix --- luxonis_train/__main__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index 3c9e6f8c..6e6d356d 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -441,10 +441,9 @@ def quantize( *, config: str | None = None, weights: str | None = None, - epochs: int = 4, ): model = create_model(config, opts, weights=weights, debug_mode=True) - model.quantize(epochs=epochs) + model.quantize() @upgrade_app.command() From fa5c9345efae52221e1b1b1e40059e5ae7deb664 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 17:17:28 +0200 Subject: [PATCH 26/74] scheduler step remove epoch --- luxonis_train/core/utils/aimet_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py index 16b9d7f5..dd59704f 100644 --- a/luxonis_train/core/utils/aimet_utils.py +++ b/luxonis_train/core/utils/aimet_utils.py @@ -129,7 +129,7 @@ def quantization_aware_training( model.cuda() model.automatic_optimization = False - for e in track( + for _ in track( range(epochs), description="Running Quantization-Aware Training", total=epochs, @@ -139,7 +139,7 @@ def quantization_aware_training( loss = model.training_step((imgs, labels)) model.manual_backward(loss) optimizer.step() - scheduler.step(e) + scheduler.step() if batch_norm_reestimation: logger.info("Reestimating batch norm statistics") From 5180a5ee79ef73ec9141da89d63a6c9a5a82775c Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 7 Apr 2026 21:25:06 +0200 Subject: [PATCH 27/74] fix qat --- luxonis_train/config/config.py | 4 +++- luxonis_train/core/utils/aimet_utils.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index d9b2d41f..cd84be43 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -749,7 +749,9 @@ class AIMETConfig(BaseModelExtraForbid): cross_layer_equalization: bool = False batch_norm_reestimation: bool = False adaround: AdaroundConfig = Field(default_factory=AdaroundConfig) - optimizer: ConfigItem | None = None + optimizer: ConfigItem = Field( + default_factory=lambda: ConfigItem(name="SGD", params={"lr": 1e-5}) + ) scheduler: ConfigItem = Field( default_factory=lambda: ConfigItem( name="StepLR", params={"step_size": 5, "gamma": 0.1} diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py index dd59704f..cc2a7061 100644 --- a/luxonis_train/core/utils/aimet_utils.py +++ b/luxonis_train/core/utils/aimet_utils.py @@ -137,7 +137,7 @@ def quantization_aware_training( for imgs, labels in train_loader: optimizer.zero_grad() loss = model.training_step((imgs, labels)) - model.manual_backward(loss) + loss.backward() optimizer.step() scheduler.step() From 87f98b5847c60fb59c546760351c1271bdf543b6 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 08:52:53 +0200 Subject: [PATCH 28/74] remmoved hidden node --- luxonis_train/attached_modules/base_attached_module.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py index 5a5ca0e5..3edb8d19 100644 --- a/luxonis_train/attached_modules/base_attached_module.py +++ b/luxonis_train/attached_modules/base_attached_module.py @@ -57,7 +57,7 @@ def __init__(self, *, node: BaseNode | None = None, **kwargs): @param kwargs: Additional keyword arguments. """ super().__init__(**kwargs) - self._node = (node,) + self._node = node if node is not None and node.task is not None: if ( @@ -124,13 +124,12 @@ def node(self) -> BaseNode: @raises RuntimeError: If the node was not provided during initialization. """ - node = self._node[0] - if node is None: + if self._node is None: raise RuntimeError( "Attempt to access `node` reference, but it was not " "provided during initialization." ) - return node + return self._node @property def n_keypoints(self) -> int: From ffa5bda85a37ab3c82d3de574af7904131ff685c Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 08:53:19 +0200 Subject: [PATCH 29/74] device switching --- luxonis_train/lightning/luxonis_lightning.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index 466f21e8..f688c917 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -294,10 +294,14 @@ def _run_forward( if compute_loss and node.losses and labels is not None: for loss_name, loss in node.losses.items(): + loss.to(self.device) + if self.training: + loss.train() losses[node_name][loss_name] = loss.run(outputs, labels) if compute_metrics and node.metrics and labels is not None: for metric in node.metrics.values(): + metric.to(self.device) metric.run_update(outputs, labels) if ( @@ -306,6 +310,7 @@ def _run_forward( and images is not None ): for viz_name, visualizer in node.visualizers.items(): + visualizer.to(self.device) viz = combine_visualizations( visualizer.run(images, images, outputs, labels), ) From c3597222aebbb8c19b4ef9183fc233be9015ed15 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 09:02:09 +0200 Subject: [PATCH 30/74] docs --- luxonis_train/callbacks/README.md | 4 ++ luxonis_train/core/core.py | 85 ++++++++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md index 40da9503..5fbae7d2 100644 --- a/luxonis_train/callbacks/README.md +++ b/luxonis_train/callbacks/README.md @@ -203,3 +203,7 @@ Callback that publishes training progress and timing metrics. | `train/epoch_progress_percent` | Percentage (0-100) of current epoch completed | | `train/epoch_duration_sec` | Time elapsed so far in current epoch | | `train/epoch_completion_sec` | Total duration of completed epoch in seconds | + +## `AIMETCallback` + +Callback to perform AIMET quantization at the end of the training. diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index d1833a22..f05bf29a 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1206,7 +1206,86 @@ def quantize( batch_norm_reestimation: bool | None = None, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, + in_place: bool = False, ) -> None: + """Runs quantization of the model using AIMET. + + @type weights: PathType | None + @param weights: Path to the checkpoint from which to load + weights. + @type epochs: int | None + @param epochs: Number of epochs to run quantization-aware + training for. + @type quant_scheme: str | QuantScheme | None + @param quant_scheme: Quantization scheme to use. If not + specified, the value from the configuration file will be + used. + @type default_output_bw: int | None + @param default_output_bw: Default bitwidth to use for quantizing + outputs. If not specified, the value from the configuration + file will be used. + @type default_param_bw: int | None + @param default_param_bw: Default bitwidth to use for quantizing + parameters. If not specified, the value from the + configuration file will be used. + @type config_file: str | None + @param config_file: Path to the AIMET configuration file or a + dictionary containing the AIMET configuration. If not + specified, the value from the configuration file will be + used. + @type default_data_type: QuantizationDataType | None + @param default_data_type: Default data type to use for + quantization. If not specified, the value from the + configuration file will be used. + @type adaround: bool | None + @param adaround: Whether to use Adaround for weight + quantization. If not specified, the value from the + configuration file will be used. + @type adaround_iterations: int | None + @param adaround_iterations: Number of iterations to run Adaround + for. If not specified, the value from the configuration file + will be used. + @type adaround_reg_param: float | None + @param adaround_reg_param: Regularization parameter to use for + Adaround. If not specified, the value from the configuration + file will be used. + @type adaround_beta_range: tuple[int, int] | None + @param adaround_beta_range: Beta range to use for Adaround. If + not specified, the value from the configuration file will be + used. + @type adaround_warm_start: float | None = None + @param adaround_warm_start: Warm start value to use for + Adaround. If not specified, the value from the configuration + file will be used. + @type fold_batch_norms: bool | None + @param fold_batch_norms: Whether to fold batch norms before + quantization. If not specified, the value from the + configuration file will be used. + @type cross_layer_equalization: bool | None + @param cross_layer_equalization: Whether to perform cross-layer + equalization before quantization. If not specified, the + value from the configuration file will be used. + @type batch_norm_reestimation: bool | None + @param batch_norm_reestimation: Whether to perform batch norm + reestimation after folding batch norms. If not specified, + the value from the configuration file will be used. + @type optimizer: Optimizer | None + @param optimizer: Optimizer to use for quantization-aware + training. If not specified, the optimizer from the + configuration file will be used. + @type scheduler: LRScheduler | None + @param scheduler: Learning rate scheduler to use for + quantization-aware training. If not specified, the scheduler + from the configuration file will be used. + @type in_place: bool + @param in_place: Whether to perform quantization in-place on the + original model or to create a copy of the model for + quantization. Defaults to False, which means that a copy of + the model will be created for quantization. Setting this to + True will modify the original model in-place, which may save + memory but will overwrite the original model's weights and + structure. + """ save_dir = self.run_save_dir / "aimet" save_dir.mkdir(parents=True, exist_ok=True) @@ -1236,7 +1315,11 @@ def quantize( else cfg.batch_norm_reestimation ) - model = deepcopy(self.lightning_module) + if not in_place: + model = deepcopy(self.lightning_module) + else: + model = self.lightning_module + model.reparametrize().eval() pre_quant_test = self.pl_trainer.test(model, self.val_loader)[0] From b77c1264d19c5a1d31191621a40d1f327303c038 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 09:09:00 +0200 Subject: [PATCH 31/74] renamed --- luxonis_train/callbacks/gradcam_visualizer.py | 2 +- luxonis_train/lightning/luxonis_lightning.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/luxonis_train/callbacks/gradcam_visualizer.py b/luxonis_train/callbacks/gradcam_visualizer.py index 95ec67f0..0f23ff77 100644 --- a/luxonis_train/callbacks/gradcam_visualizer.py +++ b/luxonis_train/callbacks/gradcam_visualizer.py @@ -47,7 +47,7 @@ def forward(self, inputs: Tensor, *args, **kwargs) -> Tensor: @return: The processed output based on the task type. """ input_dict = {"image": inputs} - output = self.pl_module._run_forward(input_dict, *args, **kwargs) + output = self.pl_module.full_forward(input_dict, *args, **kwargs) if len(output.outputs) > 1: logger.warning( "Model has multiple heads. Using the first head for Grad-CAM." diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index f688c917..5b547998 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -201,7 +201,7 @@ def forward( @rtype: dict[str, L{Packet}[L{Tensor}]] @return: Output of the model. """ - outputs = self._run_forward( + outputs = self.full_forward( inputs, compute_loss=False, compute_metrics=False, @@ -226,7 +226,7 @@ def forward( return tuple(new_outputs) - def _run_forward( + def full_forward( self, inputs: dict[str, Tensor] | Tensor, labels: Labels | None = None, @@ -397,7 +397,7 @@ def export_onnx(self, save_path: PathType, **kwargs) -> Path: def training_step( self, train_batch: tuple[dict[str, Tensor], Labels] ) -> Tensor: - outputs = self._run_forward(*train_batch) + outputs = self.full_forward(*train_batch) if not outputs.losses: raise ValueError("Losses are empty, check if you defined any loss") @@ -423,7 +423,7 @@ def predict_step( ) -> LuxonisOutput: inputs, labels = batch images = get_denormalized_images(self.cfg, inputs[self.image_source]) - return self._run_forward( + return self.full_forward( inputs, labels, images=images, @@ -645,7 +645,7 @@ def _evaluation_step( if self._n_logged_images < max_log_images: images = get_denormalized_images(self.cfg, input_image) - outputs = self._run_forward( + outputs = self.full_forward( inputs, labels, images=images, @@ -978,7 +978,7 @@ def _strip_state_prefix(key: str) -> str: return ".".join(key.split(".")[idx:]) def _get_output_onnx_names(self, inputs: dict[str, Tensor]) -> list[str]: - outputs = self._run_forward(inputs).outputs + outputs = self.full_forward(inputs).outputs output_order = sorted( [ (node_name, output_name, i) From 068d8f19547d24f083aa2bf8e42f59ed0b40f19b Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 09:15:21 +0200 Subject: [PATCH 32/74] removed ignores --- luxonis_train/assigners/__init__.py | 9 --------- .../attached_modules/base_attached_module.py | 4 ---- .../losses/efficient_keypoint_bbox_loss.py | 20 +++++++++---------- .../reconstruction_segmentation_loss.py | 6 ------ .../recognition_confusion_matrix.py | 7 ------- .../attached_modules/metrics/torchmetrics.py | 4 ---- luxonis_train/lightning/utils.py | 10 +++------- 7 files changed, 12 insertions(+), 48 deletions(-) diff --git a/luxonis_train/assigners/__init__.py b/luxonis_train/assigners/__init__.py index 6aa863c9..0b8b074a 100644 --- a/luxonis_train/assigners/__init__.py +++ b/luxonis_train/assigners/__init__.py @@ -1,13 +1,4 @@ -from contextlib import suppress - from .atss_assigner import ATSSAssigner from .tal_assigner import TaskAlignedAssigner -with suppress(ImportError): - from aimet_torch.v2.nn import QuantizationMixin - - QuantizationMixin.ignore(ATSSAssigner) - QuantizationMixin.ignore(TaskAlignedAssigner) - - __all__ = ["ATSSAssigner", "TaskAlignedAssigner"] diff --git a/luxonis_train/attached_modules/base_attached_module.py b/luxonis_train/attached_modules/base_attached_module.py index 3edb8d19..67487ddf 100644 --- a/luxonis_train/attached_modules/base_attached_module.py +++ b/luxonis_train/attached_modules/base_attached_module.py @@ -81,10 +81,6 @@ def __init__(self, *, node: BaseNode | None = None, **kwargs): self._task = None self._check_node_type_override() - with suppress(ImportError): - from aimet_torch.v2.nn import QuantizationMixin - - QuantizationMixin.ignore(self.__class__) @property def current_epoch(self) -> int: diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 3e27c684..4e27edec 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -83,6 +83,13 @@ def __init__( ) self.regr_kpts_loss_weight = regr_kpts_loss_weight self.vis_kpts_loss_weight = vis_kpts_loss_weight + self.register_buffer( + "gt_kpts_scale", + torch.tensor( + [self.original_img_size[1], self.original_img_size[0]], + ), + persistent=False, + ) def forward( self, @@ -93,14 +100,14 @@ def forward( target_boundingbox: Tensor, target_keypoints: Tensor, ) -> tuple[Tensor, dict[str, Tensor]]: + self._init_parameters(features) + device = keypoints_raw.device target_keypoints = insert_class(target_keypoints, target_boundingbox) batch_size = class_scores.shape[0] n_kpts = (target_keypoints.shape[1] - 2) // 3 - self._init_parameters(features) - pred_bboxes = dist2bbox(distributions, self.anchor_points_strided) keypoints_raw = self.dist2kpts_noscale( self.anchor_points_strided, @@ -262,12 +269,3 @@ def dist2kpts_noscale(self, anchor_points: Tensor, kpts: Tensor) -> Tensor: adj_kpts[..., 0] += x_adj adj_kpts[..., 1] += y_adj return adj_kpts - - def _init_parameters(self, features: list[Tensor]) -> None: - if hasattr(self, "gt_kpts_scale"): - return - super()._init_parameters(features) - self.gt_kpts_scale = torch.tensor( - [self.original_img_size[1], self.original_img_size[0]], - device=features[0].device, - ) diff --git a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py index 77622d01..f71cd762 100644 --- a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py @@ -1,4 +1,3 @@ -from contextlib import suppress from math import exp from typing import Literal @@ -85,11 +84,6 @@ def __init__( self.channel = 1 self.window = create_window(window_size) - with suppress(ImportError): - from aimet_torch.v2.nn import QuantizationMixin - - QuantizationMixin.ignore(self.__class__) - def forward(self, img1: Tensor, img2: Tensor) -> Tensor: device = img1.device with amp.autocast(device_type=device.type, enabled=False): diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py index 87f5d6d0..261e2cc4 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix/recognition_confusion_matrix.py @@ -1,5 +1,3 @@ -from contextlib import suppress - from torch import Tensor from torchmetrics.classification import ( BinaryConfusionMatrix, @@ -29,11 +27,6 @@ def __init__(self, **kwargs): else: self.metric = MulticlassConfusionMatrix(num_classes=self.n_classes) - with suppress(ImportError): - from aimet_torch.v2.nn import QuantizationMixin - - QuantizationMixin.ignore(self.metric.__class__) - @override def update(self, predictions: Tensor, targets: Tensor) -> None: if self.n_classes > 1: diff --git a/luxonis_train/attached_modules/metrics/torchmetrics.py b/luxonis_train/attached_modules/metrics/torchmetrics.py index 1a233314..95efb5d0 100644 --- a/luxonis_train/attached_modules/metrics/torchmetrics.py +++ b/luxonis_train/attached_modules/metrics/torchmetrics.py @@ -52,10 +52,6 @@ def __init__(self, **kwargs): kwargs["num_labels"] = n_classes self.metric = self.Metric(**kwargs) - with suppress(ImportError): - from aimet_torch.v2.nn import QuantizationMixin - - QuantizationMixin.ignore(self.metric.__class__) @override def update(self, predictions: Tensor, target: Tensor) -> None: diff --git a/luxonis_train/lightning/utils.py b/luxonis_train/lightning/utils.py index d94290fc..9afff579 100644 --- a/luxonis_train/lightning/utils.py +++ b/luxonis_train/lightning/utils.py @@ -83,9 +83,9 @@ def __init__( super().__init__() self.name = name self.module = module - self.losses = _to_module_dict(losses) - self.metrics = _to_module_dict(metrics) - self.visualizers = _to_module_dict(visualizers) + self.losses = losses + self.metrics = metrics + self.visualizers = visualizers self.unfreeze_after = unfreeze_after self.lr_after_unfreeze = lr_after_unfreeze self.inputs = inputs or [] @@ -568,10 +568,6 @@ def _init_attached_module( A = TypeVar("A", BaseLoss, BaseMetric, BaseVisualizer) -def _to_module_dict(modules: dict[str, A]) -> dict[str, A]: - return nn.ModuleDict(modules) # type: ignore - - def log_balanced_class_images( tracker: LuxonisTrackerPL, nodes: Nodes, From 5769e46257a308daa9b82c779bf18c4fb271d78d Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 09:15:26 +0200 Subject: [PATCH 33/74] updated tests --- tests/integration/test_callbacks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py index 0693d771..ee99e44d 100644 --- a/tests/integration/test_callbacks.py +++ b/tests/integration/test_callbacks.py @@ -58,6 +58,7 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path): "exporter.blobconverter.active": True, "exporter.aimet": { "active": True, + "default_num_iterations": 1, "fold_batch_norms": True, "batch_norm_reestimation": True, "cross_layer_equalization": True, From fb72ec6f9cb735475e25b457e8f67e9674a25b6b Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 09:49:59 +0200 Subject: [PATCH 34/74] affine quant --- .../nodes/backbones/pplcnet_v3/blocks.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py b/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py index 57f3dc63..e41d6fbf 100644 --- a/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py +++ b/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py @@ -1,4 +1,5 @@ import torch +from aimet_torch.v2.nn import QuantizationMixin from torch import Tensor, nn from typeguard import typechecked @@ -31,6 +32,31 @@ def forward(self, x: Tensor) -> Tensor: return self.scale * x + self.bias +@QuantizationMixin.implements(AffineBlock) +class QuantizedAffineBlock(QuantizationMixin, AffineBlock): + def __quant_init__(self): + super().__quant_init__() + + # Declare the number of input/output quantizers + self.input_quantizers = nn.ModuleList([None]) # type: ignore + self.output_quantizers = nn.ModuleList([None]) # type: ignore + + def forward(self, x: Tensor) -> Tensor: + # Quantize input tensors + if self.input_quantizers[0]: + x = self.input_quantizers[0](x) + + # Run forward with quantized inputs and parameters + with self._patch_quantized_parameters(): + ret = super().forward(x) + + # Quantize output tensors + if self.output_quantizers[0]: + ret = self.output_quantizers[0](ret) + + return ret + + class LCNetV3Block(nn.Module): @typechecked def __init__( From 1e95e2abf67520bc25afdfd7906945767636e2ac Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 09:51:25 +0200 Subject: [PATCH 35/74] reordered --- luxonis_train/config/config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index cd84be43..c120972a 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -739,16 +739,19 @@ class AdaroundConfig(BaseModelExtraForbid): class AIMETConfig(BaseModelExtraForbid): active: bool = False - epochs: PositiveInt = 20 + default_output_bw: Literal[4, 8, 16] = 8 default_param_bw: Literal[4, 8, 16] = 8 default_data_type: QuantizationDataType = QuantizationDataType.int quant_scheme: QuantScheme = QuantScheme.min_max config: Params | None = None + fold_batch_norms: bool = False cross_layer_equalization: bool = False batch_norm_reestimation: bool = False adaround: AdaroundConfig = Field(default_factory=AdaroundConfig) + + epochs: PositiveInt = 20 optimizer: ConfigItem = Field( default_factory=lambda: ConfigItem(name="SGD", params={"lr": 1e-5}) ) From 2476ef77d9d612bb2c675612bf09ec5d3492cd89 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 10:02:06 +0200 Subject: [PATCH 36/74] fix types --- luxonis_train/__main__.py | 12 ++++++++---- luxonis_train/core/utils/infer_utils.py | 2 +- tests/unittests/test_loaders/test_base_loader.py | 3 ++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index 6e6d356d..43ddd062 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -142,10 +142,14 @@ def _yield_visualizations( loader = model.loaders[view] for images, labels in loader: - np_images = { - k: v.numpy().transpose(1, 2, 0) for k, v in images.items() - } - main_image = np_images[loader.image_source] + if isinstance(images, dict): + np_images = { + k: v.numpy().transpose(1, 2, 0) for k, v in images.items() + } + main_image = np_images[loader.image_source] + else: + main_image = images.numpy().transpose(1, 2, 0) + main_image = cv2.cvtColor(main_image, cv2.COLOR_RGB2BGR).astype( np.uint8 ) diff --git a/luxonis_train/core/utils/infer_utils.py b/luxonis_train/core/utils/infer_utils.py index 05e7ee2a..55ff3eab 100644 --- a/luxonis_train/core/utils/infer_utils.py +++ b/luxonis_train/core/utils/infer_utils.py @@ -58,7 +58,7 @@ def prepare_and_infer_image( npy_img = model.loaders["val"].augment_test_image(images) torch_img = torch.tensor(npy_img).unsqueeze(0).permute(0, 3, 1, 2).float() - return model.lightning_module.run_forward_step( + return model.lightning_module.full_forward( {model.lightning_module.image_source: torch_img}, images=get_denormalized_images(model.cfg, torch_img), compute_visualizations=True, diff --git a/tests/unittests/test_loaders/test_base_loader.py b/tests/unittests/test_loaders/test_base_loader.py index 4be1e5c9..295f6a4a 100644 --- a/tests/unittests/test_loaders/test_base_loader.py +++ b/tests/unittests/test_loaders/test_base_loader.py @@ -9,7 +9,7 @@ class DummyLoader(BaseLoaderTorch): def __len__(self) -> int: ... - def get(self, idx: int) -> LuxonisLoaderTorchOutput: ... + def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput: ... def get_classes(self) -> dict[str, dict[str, int]]: ... @@ -68,6 +68,7 @@ def build_batch_element() -> LuxonisLoaderTorchOutput: inputs, annotations = loader.collate_fn(batch) with subtests.test("inputs"): + assert isinstance(inputs, dict) assert inputs["features"].shape == (batch_size, 3, 224, 224) assert inputs["features"].dtype == torch.float32 From faef813d59f2d2a31bd3afe5020c0ca2f5a20b0a Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 10:17:35 +0200 Subject: [PATCH 37/74] fix config test --- luxonis_train/config/config.py | 19 ++----------------- luxonis_train/core/core.py | 4 ++-- 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index acdca73c..9cadeaa0 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Annotated, Any, Literal, NamedTuple -from aimet_torch.common.defs import QuantizationDataType, QuantScheme from loguru import logger from luxonis_ml.enums import DatasetType from luxonis_ml.typing import ( @@ -743,8 +742,8 @@ class AIMETConfig(BaseModelExtraForbid): default_output_bw: Literal[4, 8, 16] = 8 default_param_bw: Literal[4, 8, 16] = 8 - default_data_type: QuantizationDataType = QuantizationDataType.int - quant_scheme: QuantScheme = QuantScheme.min_max + default_data_type: Literal["int", "float"] = "int" + quant_scheme: Literal["min_max", "tf_enhanced"] = "min_max" config: Params | None = None fold_batch_norms: bool = False @@ -762,20 +761,6 @@ class AIMETConfig(BaseModelExtraForbid): ) ) - @field_validator("quant_scheme", mode="before") - @classmethod - def validate_quant_scheme(cls, value: ParamValue) -> Any: - if isinstance(value, str): - return QuantScheme.from_str(value) - return value - - @field_validator("default_data_type", mode="before") - @classmethod - def validate_default_data_type(cls, value: ParamValue) -> Any: - if isinstance(value, str): - return QuantizationDataType[value.lower()] - return value - @field_validator("config", mode="before") @classmethod def validate_config(cls, value: ParamValue) -> Any: diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index f05bf29a..feb77042 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1346,10 +1346,10 @@ def quantize( dummy_inputs, self.val_loader, save_dir, - quant_scheme or cfg.quant_scheme, + quant_scheme or QuantScheme.from_str(cfg.quant_scheme), default_output_bw or cfg.default_output_bw, default_param_bw or cfg.default_param_bw, - default_data_type or cfg.default_data_type, + default_data_type or QuantizationDataType[cfg.default_data_type], aimet_config_file, adaround, adaround_iterations or cfg.adaround.default_num_iterations, From 51f897de44140dcb51c58dff6f37be7559e357b1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 10:51:12 +0200 Subject: [PATCH 38/74] requirements --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index da5f7e75..315d6685 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,6 @@ semver~=3.0 tabulate~=0.9 tensorboard~=2.20 termcolor~=3.2 -torch<2.11 # temporary pin: GitHub runner image does not yet support the newer NVIDIA driver requirement used by torch 2.11 wheels torchmetrics~=1.8 torchvision~=0.24 hubai-sdk>=0.2.1 From 5f5ddfad8068b9e94d5db892524c09550bfbab4d Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 11:03:35 +0200 Subject: [PATCH 39/74] removed quant ignore --- .../attached_modules/losses/adaptive_detection_loss.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py index f10e1c99..fdcb7cee 100644 --- a/luxonis_train/attached_modules/losses/adaptive_detection_loss.py +++ b/luxonis_train/attached_modules/losses/adaptive_detection_loss.py @@ -1,4 +1,3 @@ -from contextlib import suppress from typing import Literal, cast import torch @@ -297,10 +296,6 @@ def __init__( self.alpha = alpha self.gamma = gamma self.per_class_weights = per_class_weights - with suppress(ImportError): - from aimet_torch.v2.nn import QuantizationMixin - - QuantizationMixin.ignore(self.__class__) def forward( self, pred_score: Tensor, target_score: Tensor, label: Tensor From 4c0d8d6df64fc7e8e7561197526c634520c21266 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 11:06:09 +0200 Subject: [PATCH 40/74] simplify --- luxonis_train/attached_modules/metrics/base_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/attached_modules/metrics/base_metric.py b/luxonis_train/attached_modules/metrics/base_metric.py index 5bd55e20..d0e5cec2 100644 --- a/luxonis_train/attached_modules/metrics/base_metric.py +++ b/luxonis_train/attached_modules/metrics/base_metric.py @@ -165,7 +165,7 @@ def compute( return super().compute() def __eq__(self, other: object) -> bool: - return id(self) == id(other) + return self is other def __hash__(self) -> int: return id(self) From b7865957dd4621962c5b9f057b6ba3ea2f7dc30f Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 11:07:43 +0200 Subject: [PATCH 41/74] fix readme --- luxonis_train/callbacks/README.md | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md index feaf4b9a..9b4c500e 100644 --- a/luxonis_train/callbacks/README.md +++ b/luxonis_train/callbacks/README.md @@ -186,6 +186,10 @@ A callback that maintains an exponential moving average (EMA) of the model's par | `use_dynamic_decay` | `bool` | `True` | If enabled, adjusts the decay factor dynamically based on the training iteration. | | `decay_tau` | `float` | `2000` | The time constant (tau) for dynamic decay, influencing how quickly the EMA adapts. | +## `AIMETCallback` + +# Callback to perform AIMET quantization at the end of the training. + ## `TrainingProgressCallback` Callback that publishes training progress and timing metrics. @@ -203,14 +207,3 @@ Callback that publishes training progress and timing metrics. | `train/epoch_progress_percent` | Percentage (0-100) of current epoch completed | | `train/epoch_duration_sec` | Time elapsed so far in current epoch | | `train/epoch_completion_sec` | Total duration of completed epoch in seconds | - -## `AIMETCallback` - -# Callback to perform AIMET quantization at the end of the training. - -| Metric Key | Description | -| ------------------------------ | ------------------------------------------------------- | -| `train/epoch_progress_percent` | Percentage (0-100) of current epoch completed | -| `train/epoch_duration_sec` | Time elapsed so far in current epoch | -| `train/epoch_completion_sec` | Total duration of completed training epoch in seconds | -| `val/epoch_completion_sec` | Total duration of completed validation epoch in seconds | From 8dffdecdee0f9626be58ec038a2fedde1ae3ac7e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 11:11:22 +0200 Subject: [PATCH 42/74] fixed optimizer --- luxonis_train/core/core.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index feb77042..fa669c6a 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1366,12 +1366,11 @@ def quantize( ptq_test = self.pl_trainer.test(model, self.val_loader)[0] if optimizer is None: - opt_cfg = cfg.optimizer or self.cfg.trainer.optimizer optimizer = from_registry( OPTIMIZERS, - opt_cfg.name, + cfg.optimizer.name, params=sim.model.parameters(), - **opt_cfg.params, + **cfg.optimizer.params, ) if scheduler is None: scheduler = from_registry( From 5b09694beee7928e81f0e08b5bf50033e85a0d95 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 13:51:23 +0200 Subject: [PATCH 43/74] req --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 315d6685..83c38355 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,4 @@ termcolor~=3.2 torchmetrics~=1.8 torchvision~=0.24 hubai-sdk>=0.2.1 -aimet-torch~=2.27 +aimet-torch==2.27 From 67953c498a3a9f4c638a3c97bd1d6ab4426c3beb Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 19:37:15 +0200 Subject: [PATCH 44/74] separated requerements --- luxonis_train/config/config.py | 2 +- luxonis_train/core/core.py | 19 +++++---- .../nodes/backbones/pplcnet_v3/blocks.py | 42 ++++++++++--------- pyproject.toml | 4 +- requirements-aimet.txt | 3 ++ requirements.txt | 1 - 6 files changed, 40 insertions(+), 31 deletions(-) create mode 100644 requirements-aimet.txt diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 9cadeaa0..3784d960 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -743,7 +743,7 @@ class AIMETConfig(BaseModelExtraForbid): default_output_bw: Literal[4, 8, 16] = 8 default_param_bw: Literal[4, 8, 16] = 8 default_data_type: Literal["int", "float"] = "int" - quant_scheme: Literal["min_max", "tf_enhanced"] = "min_max" + quant_scheme: Literal["min_max", "tf", "tf_enhanced"] = "min_max" config: Params | None = None fold_batch_norms: bool = False diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index fa669c6a..52bed6ff 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -12,7 +12,6 @@ import torch import torch.utils.data as torch_data import yaml -from aimet_torch.common.defs import QuantizationDataType, QuantScheme from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.utilities import rank_zero_only from loguru import logger @@ -54,10 +53,6 @@ setup_logging, ) -from .utils.aimet_utils import ( - post_training_quantization, - quantization_aware_training, -) from .utils.annotate_utils import annotate_from_directory from .utils.archive_utils import ( get_head_configs, @@ -1191,11 +1186,11 @@ def quantize( self, weights: PathType | None = None, epochs: int | None = None, - quant_scheme: str | QuantScheme | None = None, + quant_scheme: Literal["min_max", "tf", "tf_enhanced"] | None = None, default_output_bw: int | None = None, default_param_bw: int | None = None, config_file: str | None = None, - default_data_type: QuantizationDataType | None = None, + default_data_type: Literal["int", "float"] | None = None, adaround: bool | None = None, adaround_iterations: int | None = None, adaround_reg_param: float | None = None, @@ -1286,6 +1281,12 @@ def quantize( memory but will overwrite the original model's weights and structure. """ + from aimet_torch.common.defs import QuantizationDataType, QuantScheme + + from .utils.aimet_utils import ( + post_training_quantization, + quantization_aware_training, + ) save_dir = self.run_save_dir / "aimet" save_dir.mkdir(parents=True, exist_ok=True) @@ -1346,10 +1347,10 @@ def quantize( dummy_inputs, self.val_loader, save_dir, - quant_scheme or QuantScheme.from_str(cfg.quant_scheme), + QuantScheme.from_str(quant_scheme or cfg.quant_scheme), default_output_bw or cfg.default_output_bw, default_param_bw or cfg.default_param_bw, - default_data_type or QuantizationDataType[cfg.default_data_type], + QuantizationDataType[default_data_type or cfg.default_data_type], aimet_config_file, adaround, adaround_iterations or cfg.adaround.default_num_iterations, diff --git a/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py b/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py index e41d6fbf..20913ab1 100644 --- a/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py +++ b/luxonis_train/nodes/backbones/pplcnet_v3/blocks.py @@ -1,5 +1,6 @@ +from contextlib import suppress + import torch -from aimet_torch.v2.nn import QuantizationMixin from torch import Tensor, nn from typeguard import typechecked @@ -32,29 +33,32 @@ def forward(self, x: Tensor) -> Tensor: return self.scale * x + self.bias -@QuantizationMixin.implements(AffineBlock) -class QuantizedAffineBlock(QuantizationMixin, AffineBlock): - def __quant_init__(self): - super().__quant_init__() +with suppress(ImportError): + from aimet_torch.v2.nn import QuantizationMixin - # Declare the number of input/output quantizers - self.input_quantizers = nn.ModuleList([None]) # type: ignore - self.output_quantizers = nn.ModuleList([None]) # type: ignore + @QuantizationMixin.implements(AffineBlock) + class QuantizedAffineBlock(QuantizationMixin, AffineBlock): + def __quant_init__(self): + super().__quant_init__() - def forward(self, x: Tensor) -> Tensor: - # Quantize input tensors - if self.input_quantizers[0]: - x = self.input_quantizers[0](x) + # Declare the number of input/output quantizers + self.input_quantizers = nn.ModuleList([None]) # type: ignore + self.output_quantizers = nn.ModuleList([None]) # type: ignore + + def forward(self, x: Tensor) -> Tensor: + # Quantize input tensors + if self.input_quantizers[0]: + x = self.input_quantizers[0](x) - # Run forward with quantized inputs and parameters - with self._patch_quantized_parameters(): - ret = super().forward(x) + # Run forward with quantized inputs and parameters + with self._patch_quantized_parameters(): + ret = super().forward(x) - # Quantize output tensors - if self.output_quantizers[0]: - ret = self.output_quantizers[0](ret) + # Quantize output tensors + if self.output_quantizers[0]: + ret = self.output_quantizers[0](ret) - return ret + return ret class LCNetV3Block(nn.Module): diff --git a/pyproject.toml b/pyproject.toml index 77b73816..968810f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,10 @@ where = ["."] [tool.setuptools.dynamic] dependencies = { file = ["requirements.txt"] } -optional-dependencies = { dev = { file = ["requirements-dev.txt"] } } version = { attr = "luxonis_train.__version__" } +[tool.setuptools.dynamic.optional-dependencies] +aimet = { file = ["requirements-aimet.txt"] } +dev = { file = ["requirements-dev.txt"] } [tool.ruff] target-version = "py310" diff --git a/requirements-aimet.txt b/requirements-aimet.txt new file mode 100644 index 00000000..5b4207f9 --- /dev/null +++ b/requirements-aimet.txt @@ -0,0 +1,3 @@ +aimet-torch==2.28 +torch==2.11 +torchvision~=0.26 diff --git a/requirements.txt b/requirements.txt index 83c38355..dfb5cd7e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,3 @@ termcolor~=3.2 torchmetrics~=1.8 torchvision~=0.24 hubai-sdk>=0.2.1 -aimet-torch==2.27 From 8ff59a6940752811c173c4b53ac49a74db96402d Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 19:38:36 +0200 Subject: [PATCH 45/74] updated ci --- .github/workflows/ci.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ea1c7719..29f5c30d 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -68,7 +68,7 @@ jobs: cache: pip - name: Install dependencies - run: pip install -e .[dev] + run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu126 - name: Install dev version of LuxonisML if: startsWith(github.head_ref, 'release/') == false @@ -147,7 +147,7 @@ jobs: cache: pip - name: Install dependencies - run: pip install -e .[dev] + run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu126 - name: Install dev version of LuxonisML if: startsWith(github.head_ref, 'release/') == false From d249b576e447f851e89a123736fafcb3989463ed Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 19:42:43 +0200 Subject: [PATCH 46/74] readme update --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 8a6618a4..91714015 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,15 @@ pip install luxonis-train This will also install the `luxonis_train` CLI. For more information on how to use it, see [CLI Usage](#cli). +### AIMET Quantization Support + +To enable support for AIMET quantization, install the `luxonis-train[aimet]` extra: + +```bash +pip install luxonis-train[aimet] --extra-index-url https://download.pytorch.org/whl/cu126 + +``` + ## 📝 Usage From 1b3b8c1faeb52e87262800860e8a9b6f414efd4a Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 8 Apr 2026 19:54:50 +0200 Subject: [PATCH 47/74] fix test --- tests/unittests/test_utils/test_dataset_metadata.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unittests/test_utils/test_dataset_metadata.py b/tests/unittests/test_utils/test_dataset_metadata.py index ee2f7478..ceb75c81 100644 --- a/tests/unittests/test_utils/test_dataset_metadata.py +++ b/tests/unittests/test_utils/test_dataset_metadata.py @@ -29,8 +29,7 @@ def test_n_keypoints(metadata: DatasetMetadata): assert metadata.n_keypoints("color-segmentation") == 0 assert metadata.n_keypoints("detection") == 0 assert metadata.n_keypoints() == 0 - with pytest.raises(ValueError, match="Task 'segmentation'"): - metadata.n_keypoints("segmentation") + assert metadata.n_keypoints("segmentation") == 0 metadata._n_keypoints["segmentation"] = 1 with pytest.raises(RuntimeError, match="different number of keypoints"): metadata.n_keypoints() From bb379ff500d35a409b05c71034d208b405faadc1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 9 Apr 2026 09:01:45 +0200 Subject: [PATCH 48/74] helper --- luxonis_train/core/utils/aimet_utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py index cc2a7061..3a637276 100644 --- a/luxonis_train/core/utils/aimet_utils.py +++ b/luxonis_train/core/utils/aimet_utils.py @@ -1,4 +1,5 @@ import math +from importlib.util import find_spec from pathlib import Path from typing import Any, cast @@ -20,6 +21,15 @@ from luxonis_train.loaders.base_loader import LuxonisLoaderTorchOutput +def check_aimet_available() -> None: + if not find_spec("aimet_torch"): + raise ImportError( + "AIMET library is not installed. Please install " + "`luxonis-train` with the `aimet` extra enabled " + "(pip install luxonis-train[aimet] --extra-index-url https://download.pytorch.org/whl/cu126)" + ) + + def post_training_quantization( model: LuxonisLightningModule, dummy_inputs: Tensor, From 64764824bcc0b9e178ba65c797030b60d6839e93 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 9 Apr 2026 11:40:10 +0200 Subject: [PATCH 49/74] required aimet fields --- luxonis_train/config/config.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 3784d960..5287f811 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -761,6 +761,33 @@ class AIMETConfig(BaseModelExtraForbid): ) ) + @model_validator(mode="before") + @classmethod + def validate_active(cls, data: Params) -> Params: + if not data.get("active", False): + return data + for required_field in [ + "fold_batch_norms", + "cross_layer_equalization", + "batch_norm_reestimation", + ]: + if required_field not in data: + raise ValueError( + f"AIMET config is active but missing required field '{required_field}'." + ) + adaround = data.get("adaround", {}) + if not isinstance(adaround, dict): + raise TypeError( + f"Invalid type for 'adaround': {type(adaround)}. " + "Expected a dict." + ) + if "active" not in adaround: + raise ValueError( + "AIMET config is active but missing required field " + "'adaround.active'." + ) + return data + @field_validator("config", mode="before") @classmethod def validate_config(cls, value: ParamValue) -> Any: From 34c0fdd6eb8b5b8f97d65ea99d349414c3dded81 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 14:56:39 +0200 Subject: [PATCH 50/74] fix test --- tests/integration/test_callbacks.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py index ee99e44d..b9e0b9e8 100644 --- a/tests/integration/test_callbacks.py +++ b/tests/integration/test_callbacks.py @@ -63,7 +63,10 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path): "batch_norm_reestimation": True, "cross_layer_equalization": True, }, - "exporter.aimet.adaround.active": True, + "exporter.aimet.adaround": { + "active": True, + "default_num_iterations": 1, + }, "loader.params.dataset_name": coco_dataset.identifier, } model = LuxonisModel(config_file, opts, debug_mode=True) From ad3e379ed268e236f49f9f2e59159497a1c79442 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 15:34:21 +0200 Subject: [PATCH 51/74] docs --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 91714015..6f10e68b 100644 --- a/README.md +++ b/README.md @@ -144,6 +144,7 @@ The CLI is the most straightforward way how to use `LuxonisTrain`. The CLI provi - `tune` - Tune the hyperparameters of the model for better performance - `inspect` - Inspect the dataset you are using and visualize the annotations - `annotate` - Annotate a directory using the model’s predictions and generate a new LDF. +- `quantize` - Quantize the model using `AIMET` quantization techniques **To get help on any command:** From 06a6c2c0b9be1e285f3013cb4acfc9e563eaa3d1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 17:33:56 +0200 Subject: [PATCH 52/74] fix prediction --- luxonis_train/core/utils/annotate_utils.py | 2 +- luxonis_train/lightning/luxonis_lightning.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/luxonis_train/core/utils/annotate_utils.py b/luxonis_train/core/utils/annotate_utils.py index 184688ae..a7874859 100644 --- a/luxonis_train/core/utils/annotate_utils.py +++ b/luxonis_train/core/utils/annotate_utils.py @@ -82,7 +82,7 @@ def annotated_dataset_generator( for imgs, metas in loader: with torch.no_grad(): - batch_out = lt_module(imgs).outputs + batch_out = lt_module.full_forward(imgs).outputs for head_name, head_output in batch_out.items(): img_paths = [Path(p) for p in metas["/metadata/path"]] diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index 1a9d756d..79fb28c2 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -97,7 +97,7 @@ class LuxonisLightningModule(pl.LightningModule): _trainer: pl.Trainer logger: LuxonisTrackerPL - __call__: Callable[..., LuxonisOutput] + __call__: Callable[..., tuple[Tensor, ...]] def __init__( self, @@ -397,7 +397,7 @@ def export_onnx(self, save_path: PathType, **kwargs) -> Path: @override def training_step( - self, train_batch: tuple[dict[str, Tensor], Labels] + self, train_batch: tuple[dict[str, Tensor] | Tensor, Labels] ) -> Tensor: outputs = self.full_forward(*train_batch) if not outputs.losses: @@ -409,22 +409,25 @@ def training_step( @override def validation_step( - self, val_batch: tuple[dict[str, Tensor], Labels] + self, val_batch: tuple[dict[str, Tensor] | Tensor, Labels] ) -> dict[str, Tensor]: return self._evaluation_step("val", *val_batch) @override def test_step( - self, test_batch: tuple[dict[str, Tensor], Labels] + self, test_batch: tuple[dict[str, Tensor] | Tensor, Labels] ) -> dict[str, Tensor]: return self._evaluation_step("test", *test_batch) @override def predict_step( - self, batch: tuple[dict[str, Tensor], Labels] + self, batch: tuple[dict[str, Tensor] | Tensor, Labels] ) -> LuxonisOutput: inputs, labels = batch - images = get_denormalized_images(self.cfg, inputs[self.image_source]) + images = get_denormalized_images( + self.cfg, + inputs[self.image_source] if isinstance(inputs, dict) else inputs, + ) return self.full_forward( inputs, labels, From fc73aec976ddc981a9339b77413d5914fb98aa41 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 18:01:59 +0200 Subject: [PATCH 53/74] fix test --- tests/integration/test_callbacks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py index b9e0b9e8..c1396b6d 100644 --- a/tests/integration/test_callbacks.py +++ b/tests/integration/test_callbacks.py @@ -58,7 +58,6 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path): "exporter.blobconverter.active": True, "exporter.aimet": { "active": True, - "default_num_iterations": 1, "fold_batch_norms": True, "batch_norm_reestimation": True, "cross_layer_equalization": True, From 96f42dad5392847bc1fff3ecc9203b4bbd0892bd Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 18:32:57 +0200 Subject: [PATCH 54/74] fix keypoint and anomaly --- .../losses/efficient_keypoint_bbox_loss.py | 15 ++++++++------- .../losses/reconstruction_segmentation_loss.py | 2 +- luxonis_train/lightning/luxonis_lightning.py | 7 +++++++ luxonis_train/lightning/utils.py | 12 ++++++++++++ 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py index 4e27edec..dfbdea71 100644 --- a/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py +++ b/luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py @@ -3,7 +3,7 @@ import torch import torch.nn.functional as F from loguru import logger -from torch import Tensor, nn +from torch import Tensor from luxonis_train.attached_modules.losses import AdaptiveDetectionLoss from luxonis_train.nodes import EfficientKeypointBBoxHead @@ -72,9 +72,7 @@ def __init__( **kwargs, ) - self.b_cross_entropy = nn.BCEWithLogitsLoss( - pos_weight=torch.tensor([viz_pw]) - ) + self.pos_weight = torch.tensor([viz_pw]) self.sigmas = get_sigmas( sigmas=sigmas, n_keypoints=self.n_keypoints, caller_name=self.name ) @@ -129,7 +127,7 @@ def forward( scaled_raw_keypoints = keypoints_raw.clone() scaled_raw_keypoints[..., :2] = scaled_raw_keypoints[ ..., :2 - ] * self.stride_tensor.view(1, -1, 1, 1) + ] * self.stride_tensor.clone().view(1, -1, 1, 1) sigmas = self.sigmas.to(device) @@ -195,8 +193,11 @@ def forward( regression_loss = ( ((1 - torch.exp(-e)) * mask).sum(dim=1) / (mask.sum(dim=1) + 1e-9) ).mean() - visibility_loss = self.b_cross_entropy.forward( - keypoints_raw[..., 2], mask + + visibility_loss = F.binary_cross_entropy_with_logits( + keypoints_raw[..., 2], + mask, + pos_weight=self.pos_weight.clone().to(device), ) one_hot_label = F.one_hot(assigned_labels.long(), self.n_classes + 1)[ diff --git a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py index f71cd762..8eb5c1ab 100644 --- a/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py +++ b/luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py @@ -92,7 +92,7 @@ def forward(self, img1: Tensor, img2: Tensor) -> Tensor: (_, channel, _, _) = img1.size() if channel == self.channel and self.window.dtype == img1.dtype: - window = self.window.to(device) + window = self.window.to(device).clone() else: window = ( create_window(self.window_size, channel) diff --git a/luxonis_train/lightning/luxonis_lightning.py b/luxonis_train/lightning/luxonis_lightning.py index 79fb28c2..fba8bff4 100644 --- a/luxonis_train/lightning/luxonis_lightning.py +++ b/luxonis_train/lightning/luxonis_lightning.py @@ -337,6 +337,13 @@ def full_forward( outputs=outputs_dict, losses=losses, visualizations=visualizations ) + @override + def train(self, mode: bool = True) -> Self: + super().train(mode) + for node in self.nodes.values(): + node.train(mode) + return self + def set_export_mode(self, mode: bool) -> Self: for module in self.modules(): if isinstance(module, BaseNode): diff --git a/luxonis_train/lightning/utils.py b/luxonis_train/lightning/utils.py index 9afff579..27c506be 100644 --- a/luxonis_train/lightning/utils.py +++ b/luxonis_train/lightning/utils.py @@ -16,6 +16,7 @@ from torch import Size, Tensor, nn from torch.optim.lr_scheduler import LRScheduler, SequentialLR from torch.optim.optimizer import Optimizer +from typing_extensions import override import luxonis_train as lxt from luxonis_train.attached_modules import BaseLoss, BaseMetric, BaseVisualizer @@ -94,6 +95,17 @@ def __init__( def task_name(self) -> str: return self.module.task_name + @override + def train(self, mode: bool = True) -> "NodeWrapper": + self.module.train(mode) + for loss in self.losses.values(): + loss.train(mode) + for metric in self.metrics.values(): + metric.train(mode) + for visualizer in self.visualizers.values(): + visualizer.train(mode) + return self + class Nodes(dict[str, NodeWrapper] if TYPE_CHECKING else nn.ModuleDict): def __init__( From d5fefe54d275fb268f482a047b39e91069a0f1a7 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 19:42:21 +0200 Subject: [PATCH 55/74] fix deepcopy --- .../attached_modules/visualizers/base_visualizer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/luxonis_train/attached_modules/visualizers/base_visualizer.py b/luxonis_train/attached_modules/visualizers/base_visualizer.py index 618ed221..565c1e93 100644 --- a/luxonis_train/attached_modules/visualizers/base_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/base_visualizer.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from luxonis_ml.data.utils import ColorMap from torch import Tensor -from typing_extensions import TypeVarTuple, Unpack +from typing_extensions import TypeVarTuple, Unpack, override from luxonis_train.attached_modules import BaseAttachedModule from luxonis_train.registry import VISUALIZERS @@ -26,6 +26,13 @@ def __init__(self, *args, scale: float = 1.0, **kwargs) -> None: super().__init__(*args, **kwargs) self.scale = scale + @override + def __getstate__(self) -> dict: + state = super().__getstate__() + if "colormap" in state: + del state["colormap"] + return state + @staticmethod def scale_canvas(canvas: Tensor, scale: float = 1.0) -> Tensor: return F.interpolate( From 5993a6f554ab83dca4eda64712e6d744d3719965 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 19:56:30 +0200 Subject: [PATCH 56/74] seq mse --- configs/README.md | 1 + luxonis_train/config/config.py | 2 ++ luxonis_train/core/core.py | 10 +++++++++- luxonis_train/core/utils/aimet_utils.py | 15 +++++++++++++++ tests/conftest.py | 12 ++++++++++++ tests/integration/test_callbacks.py | 10 ---------- 6 files changed, 39 insertions(+), 11 deletions(-) diff --git a/configs/README.md b/configs/README.md index 9e3353a8..bd18cd52 100644 --- a/configs/README.md +++ b/configs/README.md @@ -587,6 +587,7 @@ The [AIMET](https://quic.github.io/aimet-pages/releases/latest/index.html) (AI M | `fold_batch_norms` | `bool` | `False` | Whether to fold batch normalization layers before quantization | | `cross_layer_equalization` | `bool` | `False` | Whether to perform cross-layer equalization before quantization | | `batch_norm_reestimation` | `bool` | `False` | Whether to perform batch norm re-estimation after quantization | +| `sequential_mse` | `bool` | `False` | Whether to perform sequential MSE optimization. | | `optimizer` | `dict` | `{}` | Optimizer configuration for quantization-aware training. See [Optimizer](#optimizer) section for details and examples. If not set, the `trainer` optimizer is used. | | `scheduler` | `dict` | `{}` | Scheduler configuration for quantization-aware training. See [Scheduler](#scheduler) section for details and examples. If not set, the `trainer` scheduler is used. | | `adaround` | `dict` | `{}` | Configuration for Adaround weight rounding. See [Adaround](#adaround) for more details. | diff --git a/luxonis_train/config/config.py b/luxonis_train/config/config.py index 2a4ca1f7..f569cca0 100644 --- a/luxonis_train/config/config.py +++ b/luxonis_train/config/config.py @@ -741,6 +741,7 @@ class AIMETConfig(BaseModelExtraForbid): fold_batch_norms: bool = False cross_layer_equalization: bool = False batch_norm_reestimation: bool = False + sequential_mse: bool = False adaround: AdaroundConfig = Field(default_factory=AdaroundConfig) epochs: PositiveInt = 20 @@ -762,6 +763,7 @@ def validate_active(cls, data: Params) -> Params: "fold_batch_norms", "cross_layer_equalization", "batch_norm_reestimation", + "sequential_mse", ]: if required_field not in data: raise ValueError( diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 52bed6ff..0520f99a 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1199,10 +1199,11 @@ def quantize( fold_batch_norms: bool | None = None, cross_layer_equalization: bool | None = None, batch_norm_reestimation: bool | None = None, + sequential_mse: bool | None = None, optimizer: Optimizer | None = None, scheduler: LRScheduler | None = None, in_place: bool = False, - ) -> None: + ) -> Path: """Runs quantization of the model using AIMET. @type weights: PathType | None @@ -1315,6 +1316,11 @@ def quantize( if batch_norm_reestimation is not None else cfg.batch_norm_reestimation ) + sequential_mse = ( + sequential_mse + if sequential_mse is not None + else cfg.sequential_mse + ) if not in_place: model = deepcopy(self.lightning_module) @@ -1360,6 +1366,7 @@ def quantize( fold_batch_norms, cross_layer_equalization, batch_norm_reestimation, + sequential_mse, ) model = cast(LuxonisLightningModule, sim.model) @@ -1412,6 +1419,7 @@ def quantize( table, ["Name", "Pre-Quant", "PTQ", "QAT"], ) + return save_dir @property def environ(self) -> Environ: diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py index 3a637276..eba69d74 100644 --- a/luxonis_train/core/utils/aimet_utils.py +++ b/luxonis_train/core/utils/aimet_utils.py @@ -7,7 +7,11 @@ from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters from aimet_torch.bn_reestimation import reestimate_bn_stats from aimet_torch.common.defs import QuantizationDataType, QuantScheme +from aimet_torch.common.quantsim_config.utils import ( + get_path_for_per_channel_config, +) from aimet_torch.cross_layer_equalization import equalize_model +from aimet_torch.seq_mse import apply_seq_mse from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms from lightning.pytorch.accelerators import CUDAAccelerator from loguru import logger @@ -48,6 +52,7 @@ def post_training_quantization( fold_batch_norms: bool = False, cross_layer_equalization: bool = False, batch_norm_reestimation: bool = False, + sequential_mse: bool = False, ) -> QuantizationSimModel: def pass_calibration_data(model: nn.Module) -> None: @@ -101,6 +106,9 @@ def pass_calibration_data(model: nn.Module) -> None: ), ) + if batch_norm_reestimation and config_file is None: + config_file = get_path_for_per_channel_config() + sim = QuantizationSimModel( model=model, dummy_input=dummy_inputs, @@ -111,6 +119,13 @@ def pass_calibration_data(model: nn.Module) -> None: default_data_type=default_data_type, in_place=True, ) + if sequential_mse: + logger.info("Applying sequential MSE") + apply_seq_mse( + sim, + data_loader=(imgs for imgs, _ in val_loader), # type: ignore + num_candidates=20, + ) if adaround: sim.set_and_freeze_param_encodings( diff --git a/tests/conftest.py b/tests/conftest.py index 5fab5f96..12a37ed7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -405,6 +405,18 @@ def opts(save_dir: Path, image_size: tuple[int, int]) -> Params: ], "tracker.save_directory": str(save_dir), "trainer.preprocessing.train_image_size": image_size, + "exporter.aimet": { + "active": True, + "epochs": 1, + "fold_batch_norms": True, + "batch_norm_reestimation": True, + "cross_layer_equalization": True, + "sequential_mse": True, + }, + "exporter.aimet.adaround": { + "active": True, + "default_num_iterations": 1, + }, } diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py index c1396b6d..afea91c0 100644 --- a/tests/integration/test_callbacks.py +++ b/tests/integration/test_callbacks.py @@ -56,16 +56,6 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path): "exporter.scale_values": [0.5, 0.5, 0.5], "exporter.mean_values": [0.5, 0.5, 0.5], "exporter.blobconverter.active": True, - "exporter.aimet": { - "active": True, - "fold_batch_norms": True, - "batch_norm_reestimation": True, - "cross_layer_equalization": True, - }, - "exporter.aimet.adaround": { - "active": True, - "default_num_iterations": 1, - }, "loader.params.dataset_name": coco_dataset.identifier, } model = LuxonisModel(config_file, opts, debug_mode=True) From 24ac1508871358f7d24d984aea316e61b40891be Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 22:52:33 +0200 Subject: [PATCH 57/74] fix tests --- .../visualizers/embeddings_visualizer.py | 6 ++++-- luxonis_train/core/utils/aimet_utils.py | 19 ++++++++++++------- tests/integration/backbone_model_utils.py | 2 ++ tests/integration/test_predefined_models.py | 6 ++++++ 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py index d9f4cd7b..4d301577 100644 --- a/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py +++ b/luxonis_train/attached_modules/visualizers/embeddings_visualizer.py @@ -51,6 +51,7 @@ def forward( @return: An embedding space projection. """ embeddings_np = predictions.detach().cpu().numpy() + embeddings_np[np.isnan(embeddings_np) | np.isinf(embeddings_np)] = 0.0 ids_np = target.detach().cpu().numpy().astype(int) pca = PCA(n_components=2, random_state=42) @@ -90,8 +91,9 @@ def plot_to_tensor( plot_func: Callable[[plt.Axes, np.ndarray, np.ndarray], None], ) -> Tensor: fig, ax = plt.subplots(figsize=(10, 10)) - ax.set_xlim(embeddings_2d[:, 0].min(), embeddings_2d[:, 0].max()) - ax.set_ylim(embeddings_2d[:, 1].min(), embeddings_2d[:, 1].max()) + if embeddings_2d.size > 0: + ax.set_xlim(embeddings_2d[:, 0].min(), embeddings_2d[:, 0].max()) + ax.set_ylim(embeddings_2d[:, 1].min(), embeddings_2d[:, 1].max()) plot_func(ax, embeddings_2d, ids_np) ax.axis("off") diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py index eba69d74..a4c25085 100644 --- a/luxonis_train/core/utils/aimet_utils.py +++ b/luxonis_train/core/utils/aimet_utils.py @@ -121,10 +121,12 @@ def pass_calibration_data(model: nn.Module) -> None: ) if sequential_mse: logger.info("Applying sequential MSE") + apply_seq_mse( sim, - data_loader=(imgs for imgs, _ in val_loader), # type: ignore + data_loader=val_loader, num_candidates=20, + forward_fn=_patched_forward_pass, ) if adaround: @@ -169,12 +171,9 @@ def quantization_aware_training( if batch_norm_reestimation: logger.info("Reestimating batch norm statistics") - def _forward_pass( - model: nn.Module, inputs: LuxonisLoaderTorchOutput - ) -> Any: - return model(inputs[0]) - - reestimate_bn_stats(model, train_loader, forward_fn=_forward_pass) + reestimate_bn_stats( + model, train_loader, forward_fn=_patched_forward_pass + ) if fold_batch_norms: logger.info("Folding batch norms into preceding layers") @@ -186,3 +185,9 @@ def _forward_pass( model.automatic_optimization = True return model + + +def _patched_forward_pass( + model: nn.Module, inputs: LuxonisLoaderTorchOutput +) -> Any: + return model(inputs[0]) diff --git a/tests/integration/backbone_model_utils.py b/tests/integration/backbone_model_utils.py index b3a3ddf6..725d4628 100644 --- a/tests/integration/backbone_model_utils.py +++ b/tests/integration/backbone_model_utils.py @@ -76,5 +76,7 @@ def prepare_predefined_model_config( } elif "ocr_recognition" in config_file: opts["trainer.preprocessing.train_image_size"] = [48, 320] + elif "instance_segmentation" in config_name: + opts |= {"exporter.aimet.batch_norm_reestimation": False} return config_file, opts, dataset diff --git a/tests/integration/test_predefined_models.py b/tests/integration/test_predefined_models.py index 6e89edba..2c6f0928 100644 --- a/tests/integration/test_predefined_models.py +++ b/tests/integration/test_predefined_models.py @@ -65,6 +65,12 @@ def test_predefined_models( model.run_save_dir / "archive" / f"{config_name}.onnx.tar.xz" ).exists() + with subtests.test("quantize"): + save_dir = model.quantize() + assert (save_dir / f"{config_name}.encodings").exists() + assert (save_dir / f"{config_name}.onnx").exists() + assert (save_dir / f"{config_name}.onnx.data").exists() + if config_name != "embeddings_model": with subtests.test("infer"): loader = LuxonisLoader(dataset) From 478046c1362968047552408198533cbed441c954 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 23:08:34 +0200 Subject: [PATCH 58/74] updated import --- luxonis_train/core/utils/aimet_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/core/utils/aimet_utils.py b/luxonis_train/core/utils/aimet_utils.py index a4c25085..8c55c363 100644 --- a/luxonis_train/core/utils/aimet_utils.py +++ b/luxonis_train/core/utils/aimet_utils.py @@ -5,6 +5,7 @@ from aimet_torch import QuantizationSimModel from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters +from aimet_torch.batch_norm_fold import fold_all_batch_norms from aimet_torch.bn_reestimation import reestimate_bn_stats from aimet_torch.common.defs import QuantizationDataType, QuantScheme from aimet_torch.common.quantsim_config.utils import ( @@ -12,7 +13,6 @@ ) from aimet_torch.cross_layer_equalization import equalize_model from aimet_torch.seq_mse import apply_seq_mse -from aimet_torch.v1.batch_norm_fold import fold_all_batch_norms from lightning.pytorch.accelerators import CUDAAccelerator from loguru import logger from rich.progress import track From 3fc8cb4f0f08b933f251b7b5020fb61d246d9927 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Mon, 13 Apr 2026 23:39:47 +0200 Subject: [PATCH 59/74] fix test --- luxonis_train/nodes/backbones/dinov3/dinov3.py | 1 + luxonis_train/nodes/backbones/efficientnet.py | 1 + luxonis_train/utils/dataset_metadata.py | 6 ++++-- tests/conftest.py | 2 +- tests/integration/test_callbacks.py | 2 ++ tests/integration/test_combinations.py | 4 +++- tests/integration/test_custom_model.py | 2 +- tests/integration/test_predefined_models.py | 4 +++- 8 files changed, 16 insertions(+), 6 deletions(-) diff --git a/luxonis_train/nodes/backbones/dinov3/dinov3.py b/luxonis_train/nodes/backbones/dinov3/dinov3.py index 9e88e3cc..20fcf813 100644 --- a/luxonis_train/nodes/backbones/dinov3/dinov3.py +++ b/luxonis_train/nodes/backbones/dinov3/dinov3.py @@ -206,6 +206,7 @@ def _get_backbone( model=model_name, weights=weights, source="github", + trust_repo=True, # type: ignore **kwargs, ) diff --git a/luxonis_train/nodes/backbones/efficientnet.py b/luxonis_train/nodes/backbones/efficientnet.py index d5b74728..c92c746b 100644 --- a/luxonis_train/nodes/backbones/efficientnet.py +++ b/luxonis_train/nodes/backbones/efficientnet.py @@ -43,6 +43,7 @@ class GenEfficientNet(nn.Module): "rwightman/gen-efficientnet-pytorch", "efficientnet_lite0", pretrained=weights == "download", + trust_repo=True, # type: ignore ), ) self.out_indices = out_indices or [0, 1, 2, 4, 6] diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py index 24a888fa..53f15f34 100644 --- a/luxonis_train/utils/dataset_metadata.py +++ b/luxonis_train/utils/dataset_metadata.py @@ -103,7 +103,8 @@ def n_classes(self, task_name: str | None = None) -> int: if task_name is not None: if task_name not in self._classes: raise ValueError( - f"Task '{task_name}' is not present in the dataset." + f"Task '{task_name}' is not present in the dataset. " + f"Available tasks: {self.task_names}" ) return len(self._classes[task_name]) n_classes = len(next(iter(self._classes.values()))) @@ -155,7 +156,8 @@ def classes(self, task_name: str | None = None) -> bidict[str, int]: if task_name is not None: if task_name not in self._classes: raise ValueError( - f"Task '{task_name}' is not present in the dataset." + f"Task '{task_name}' is not present in the dataset. " + f"Available tasks: {self.task_names}" ) return bidict(self._classes[task_name]) classes = next(iter(self._classes.values())) diff --git a/tests/conftest.py b/tests/conftest.py index 12a37ed7..a68d8f08 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -406,7 +406,7 @@ def opts(save_dir: Path, image_size: tuple[int, int]) -> Params: "tracker.save_directory": str(save_dir), "trainer.preprocessing.train_image_size": image_size, "exporter.aimet": { - "active": True, + "active": False, "epochs": 1, "fold_batch_norms": True, "batch_norm_reestimation": True, diff --git a/tests/integration/test_callbacks.py b/tests/integration/test_callbacks.py index afea91c0..1379ba7b 100644 --- a/tests/integration/test_callbacks.py +++ b/tests/integration/test_callbacks.py @@ -56,6 +56,8 @@ def test_callbacks(coco_dataset: LuxonisDataset, opts: Params, save_dir: Path): "exporter.scale_values": [0.5, 0.5, 0.5], "exporter.mean_values": [0.5, 0.5, 0.5], "exporter.blobconverter.active": True, + # AIMET fails when determinism is enabled + "exporter.aimet.active": False, "loader.params.dataset_name": coco_dataset.identifier, } model = LuxonisModel(config_file, opts, debug_mode=True) diff --git a/tests/integration/test_combinations.py b/tests/integration/test_combinations.py index 2c3048b6..d9a5b482 100644 --- a/tests/integration/test_combinations.py +++ b/tests/integration/test_combinations.py @@ -143,7 +143,9 @@ def test_combinations( subtests: SubTests, ): config = get_config(backbone, dinov3_weights) - opts |= {"loader.params.dataset_name": parking_lot_dataset.identifier} + opts |= { + "loader.params.dataset_name": parking_lot_dataset.identifier, + } model = LuxonisModel(config, opts) with subtests.test("train"): diff --git a/tests/integration/test_custom_model.py b/tests/integration/test_custom_model.py index c4915bc5..8c2cc8f1 100644 --- a/tests/integration/test_custom_model.py +++ b/tests/integration/test_custom_model.py @@ -30,7 +30,7 @@ def input_shapes(self): "pointcloud": torch.Size([self.n_points, 3]), } - def get(self, _: int) -> LuxonisLoaderTorchOutput: + def __getitem__(self, _: int) -> LuxonisLoaderTorchOutput: left = torch.rand(3, self.height, self.width, dtype=torch.float32) right = torch.rand(3, self.height, self.width, dtype=torch.float32) disparity = torch.rand(1, self.height, self.width, dtype=torch.float32) diff --git a/tests/integration/test_predefined_models.py b/tests/integration/test_predefined_models.py index 2c6f0928..143e4957 100644 --- a/tests/integration/test_predefined_models.py +++ b/tests/integration/test_predefined_models.py @@ -48,7 +48,9 @@ def test_predefined_models( tmp_path = tmp_path / config_name tmp_path.mkdir() - model = LuxonisModel(config_file, opts | extra_opts) + model = LuxonisModel( + config_file, opts | extra_opts | {"exporter.aimet.active": True} + ) with subtests.test("train"): model.train() From e7b4a6e90d3a2d02d4b0afdc09627116b0a0cd28 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 14 Apr 2026 06:20:06 +0200 Subject: [PATCH 60/74] updated svg --- media/anomaly_detection_diagram.drawio | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/media/anomaly_detection_diagram.drawio b/media/anomaly_detection_diagram.drawio index ee134024..7411296c 100644 --- a/media/anomaly_detection_diagram.drawio +++ b/media/anomaly_detection_diagram.drawio @@ -113,7 +113,7 @@ - + From a5558a495a85b85a4d92a214ba5a9b8776bea81e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 12 May 2026 12:42:39 +0200 Subject: [PATCH 61/74] fix test --- .../metrics/confusion_matrix/detection_confusion_matrix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py index 0f4d774e..c160e6c1 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py @@ -50,6 +50,7 @@ def compute(self) -> dict[str, Tensor]: "confusion_matrix": self.confusion_matrix, } + @torch.inference_mode() def _update(self, predictions: list[Tensor], targets: Tensor) -> None: for pred, target in zip( predictions, From d14f285d5a2e88423b401205401d02005bc2373d Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 12 May 2026 21:38:16 +0200 Subject: [PATCH 62/74] fix --- .../metrics/confusion_matrix/detection_confusion_matrix.py | 1 - requirements.txt | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py index c160e6c1..0f4d774e 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py @@ -50,7 +50,6 @@ def compute(self) -> dict[str, Tensor]: "confusion_matrix": self.confusion_matrix, } - @torch.inference_mode() def _update(self, predictions: list[Tensor], targets: Tensor) -> None: for pred, target in zip( predictions, diff --git a/requirements.txt b/requirements.txt index 44546ef2..453c1d3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,4 @@ termcolor~=3.2 torchmetrics~=1.8 torchvision~=0.24 hubai-sdk>=0.2.5 +torch<2.11 From 70f6675aeb6d42c0e89d4a4d9c1d81a473ec533a Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 19 May 2026 18:13:47 +0200 Subject: [PATCH 63/74] fix readme --- configs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/README.md b/configs/README.md index ace053fb..4f2446dd 100644 --- a/configs/README.md +++ b/configs/README.md @@ -510,7 +510,7 @@ Here you can define configuration for exporting. | `onnx` | `dict` | `{}` | Options specific for ONNX export. See [ONNX](#onnx) section for details | | `hubai` | `dict` | `{}` | Options for HubAI SDK conversion. See [HubAI](#hubai) section for details | | `blobconverter` | `dict` | `{}` | Options for converting to BLOB format (deprecated). See [Blob](#blob-deprecated) section | -| `aimet` | | `dict` | `{}` | +| `aimet` | `dict` | `{}` | Options for AIMET quantization. See [AIMET](#aimet)\] | ### `ONNX` From ed6ed80efb6973eef31246972d720b51d7326f1e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 19 May 2026 18:16:51 +0200 Subject: [PATCH 64/74] fix link --- configs/README.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/configs/README.md b/configs/README.md index 4f2446dd..438d3186 100644 --- a/configs/README.md +++ b/configs/README.md @@ -576,22 +576,22 @@ exporter: The [AIMET](https://quic.github.io/aimet-pages/releases/latest/index.html) (AI Model Efficiency Toolkit) provides quantization and model export tools. -| Key | Type | Default value | Description | -| -------------------------- | ------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `active` | `bool` | `False` | Whether to use AIMET for quantization and export | -| `epochs` | `int` | `20` | Number of epochs to use for quantization-aware training | -| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights | -| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters | -| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values | -| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use | -| `config` | `dict \| str` | `{}` | Additional configuration for AIMET. Can be a dictionary or a path to a JSON config file. Refer to the [AIMET documentation](https://quic.github.io/aimet-pages/releases/latest/docs/QuantizationSim.html#quantization-sim-config) for details on the available options. | -| `fold_batch_norms` | `bool` | `False` | Whether to fold batch normalization layers before quantization | -| `cross_layer_equalization` | `bool` | `False` | Whether to perform cross-layer equalization before quantization | -| `batch_norm_reestimation` | `bool` | `False` | Whether to perform batch norm re-estimation after quantization | -| `sequential_mse` | `bool` | `False` | Whether to perform sequential MSE optimization. | -| `optimizer` | `dict` | `{}` | Optimizer configuration for quantization-aware training. See [Optimizer](#optimizer) section for details and examples. If not set, the `trainer` optimizer is used. | -| `scheduler` | `dict` | `{}` | Scheduler configuration for quantization-aware training. See [Scheduler](#scheduler) section for details and examples. If not set, the `trainer` scheduler is used. | -| `adaround` | `dict` | `{}` | Configuration for Adaround weight rounding. See [Adaround](#adaround) for more details. | +| Key | Type | Default value | Description | +| -------------------------- | ------------------------------------------------- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `active` | `bool` | `False` | Whether to use AIMET for quantization and export | +| `epochs` | `int` | `20` | Number of epochs to use for quantization-aware training | +| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights | +| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters | +| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values | +| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use | +| `config` | `dict \| str` | `{}` | Additional configuration for AIMET. Can be a dictionary or a path to a JSON config file. Refer to the [AIMET documentation](https://quic.github.io/aimet-pages/releases/latest/techniques/runtime_config.html) for details on the available options. | +| `fold_batch_norms` | `bool` | `False` | Whether to fold batch normalization layers before quantization | +| `cross_layer_equalization` | `bool` | `False` | Whether to perform cross-layer equalization before quantization | +| `batch_norm_reestimation` | `bool` | `False` | Whether to perform batch norm re-estimation after quantization | +| `sequential_mse` | `bool` | `False` | Whether to perform sequential MSE optimization. | +| `optimizer` | `dict` | `{}` | Optimizer configuration for quantization-aware training. See [Optimizer](#optimizer) section for details and examples. If not set, the `trainer` optimizer is used. | +| `scheduler` | `dict` | `{}` | Scheduler configuration for quantization-aware training. See [Scheduler](#scheduler) section for details and examples. If not set, the `trainer` scheduler is used. | +| `adaround` | `dict` | `{}` | Configuration for Adaround weight rounding. See [Adaround](#adaround) for more details. | #### Adaround From 1d110ce626c6905ed7d59ab94a406f559bcf5476 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 19 May 2026 18:19:58 +0200 Subject: [PATCH 65/74] update --- configs/README.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/configs/README.md b/configs/README.md index 438d3186..301b64e1 100644 --- a/configs/README.md +++ b/configs/README.md @@ -576,22 +576,22 @@ exporter: The [AIMET](https://quic.github.io/aimet-pages/releases/latest/index.html) (AI Model Efficiency Toolkit) provides quantization and model export tools. -| Key | Type | Default value | Description | -| -------------------------- | ------------------------------------------------- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `active` | `bool` | `False` | Whether to use AIMET for quantization and export | -| `epochs` | `int` | `20` | Number of epochs to use for quantization-aware training | -| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights | -| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters | -| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values | -| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use | -| `config` | `dict \| str` | `{}` | Additional configuration for AIMET. Can be a dictionary or a path to a JSON config file. Refer to the [AIMET documentation](https://quic.github.io/aimet-pages/releases/latest/techniques/runtime_config.html) for details on the available options. | -| `fold_batch_norms` | `bool` | `False` | Whether to fold batch normalization layers before quantization | -| `cross_layer_equalization` | `bool` | `False` | Whether to perform cross-layer equalization before quantization | -| `batch_norm_reestimation` | `bool` | `False` | Whether to perform batch norm re-estimation after quantization | -| `sequential_mse` | `bool` | `False` | Whether to perform sequential MSE optimization. | -| `optimizer` | `dict` | `{}` | Optimizer configuration for quantization-aware training. See [Optimizer](#optimizer) section for details and examples. If not set, the `trainer` optimizer is used. | -| `scheduler` | `dict` | `{}` | Scheduler configuration for quantization-aware training. See [Scheduler](#scheduler) section for details and examples. If not set, the `trainer` scheduler is used. | -| `adaround` | `dict` | `{}` | Configuration for Adaround weight rounding. See [Adaround](#adaround) for more details. | +| Key | Type | Default value | Description | +| -------------------------- | ------------------------------------------------- | -------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `active` | `bool` | `False` | Whether to use AIMET for quantization and export | +| `epochs` | `int` | `20` | Number of epochs to use for quantization-aware training | +| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights | +| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters | +| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values | +| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use | +| `config` | `dict \| str` | `{}` | Additional configuration for AIMET. Can be a dictionary or a path to a JSON config file. Refer to the [AIMET documentation](https://quic.github.io/aimet-pages/releases/latest/techniques/runtime_config.html) for details on the available options. | +| `fold_batch_norms` | `bool` | `False` | Whether to fold batch normalization layers before quantization | +| `cross_layer_equalization` | `bool` | `False` | Whether to perform cross-layer equalization before quantization | +| `batch_norm_reestimation` | `bool` | `False` | Whether to perform batch norm re-estimation after quantization | +| `sequential_mse` | `bool` | `False` | Whether to perform sequential MSE optimization. | +| `optimizer` | `dict` | `{"name": "SGD", "params": {"lr": 1e-5}}` | Optimizer configuration for quantization-aware training. See [Optimizer](#optimizer) section for details and examples. | +| `scheduler` | `dict` | `{"name": "StepLR", "params": {"step_size": 5, "gamma": 0.1}}` | Scheduler configuration for quantization-aware training. See [Scheduler](#scheduler) section for details and examples.. | +| `adaround` | `dict` | `{}` | Configuration for Adaround weight rounding. See [Adaround](#adaround) for more details. | #### Adaround From c06d0029c831d65cc2383cd73b23abe959cc6c17 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Tue, 19 May 2026 18:32:17 +0200 Subject: [PATCH 66/74] removed requirement --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 453c1d3e..44546ef2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,3 @@ termcolor~=3.2 torchmetrics~=1.8 torchvision~=0.24 hubai-sdk>=0.2.5 -torch<2.11 From 0162f88c8245d0b092d16fccf1f85535857fa268 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Wed, 20 May 2026 10:51:46 +0200 Subject: [PATCH 67/74] fix tests --- .../metrics/confusion_matrix/detection_confusion_matrix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py index 0f4d774e..c102e921 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py @@ -58,6 +58,8 @@ def _update(self, predictions: list[Tensor], targets: Tensor) -> None: ): pred_classes = pred[:, 5].int() target_classes = target[:, 0].int() + if self.confusion_matrix.is_inference(): + self.confusion_matrix = self.confusion_matrix.clone() if target.numel() == pred.numel() == 0: self.confusion_matrix[self.n_classes, self.n_classes] += 1 From bb7f28caafc80c2ed75a160a78e3aa0d33e8636f Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 21 May 2026 12:27:25 +0200 Subject: [PATCH 68/74] change requirements --- .github/workflows/ci.yaml | 4 ++-- README.md | 2 +- requirements-aimet.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 29f5c30d..4043c364 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -68,7 +68,7 @@ jobs: cache: pip - name: Install dependencies - run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu126 + run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu130 - name: Install dev version of LuxonisML if: startsWith(github.head_ref, 'release/') == false @@ -147,7 +147,7 @@ jobs: cache: pip - name: Install dependencies - run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu126 + run: pip install -e .[dev,aimet] --extra-index-url https://download.pytorch.org/whl/cu130 - name: Install dev version of LuxonisML if: startsWith(github.head_ref, 'release/') == false diff --git a/README.md b/README.md index 2c5d1bd7..948e939d 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ This will also install the `luxonis_train` CLI. For more information on how to u To enable support for AIMET quantization, install the `luxonis-train[aimet]` extra: ```bash -pip install luxonis-train[aimet] --extra-index-url https://download.pytorch.org/whl/cu126 +pip install luxonis-train[aimet] --extra-index-url https://download.pytorch.org/whl/cu130 ``` diff --git a/requirements-aimet.txt b/requirements-aimet.txt index 5b4207f9..78ab17fb 100644 --- a/requirements-aimet.txt +++ b/requirements-aimet.txt @@ -1,3 +1,3 @@ -aimet-torch==2.28 +aimet-torch~=2.31 torch==2.11 torchvision~=0.26 From 25656ab6ab3fdacb48a73cff3ef91f067a3a84b1 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 21 May 2026 18:39:32 +0200 Subject: [PATCH 69/74] fix test --- tests/integration/test_predefined_models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/integration/test_predefined_models.py b/tests/integration/test_predefined_models.py index 8061d647..145b8b8c 100644 --- a/tests/integration/test_predefined_models.py +++ b/tests/integration/test_predefined_models.py @@ -48,9 +48,7 @@ def test_predefined_models( tmp_path = tmp_path / config_name tmp_path.mkdir() - model = LuxonisModel( - config_file, opts | extra_opts | {"exporter.aimet.active": True} - ) + model = LuxonisModel(config_file, opts | extra_opts) with subtests.test("train"): model.train() From 7bcd04b0e3056ae60f9c6a3cd074ffe29d5ff4fd Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 21 May 2026 18:46:02 +0200 Subject: [PATCH 70/74] fix readme --- configs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/README.md b/configs/README.md index 301b64e1..6e3e4202 100644 --- a/configs/README.md +++ b/configs/README.md @@ -510,7 +510,7 @@ Here you can define configuration for exporting. | `onnx` | `dict` | `{}` | Options specific for ONNX export. See [ONNX](#onnx) section for details | | `hubai` | `dict` | `{}` | Options for HubAI SDK conversion. See [HubAI](#hubai) section for details | | `blobconverter` | `dict` | `{}` | Options for converting to BLOB format (deprecated). See [Blob](#blob-deprecated) section | -| `aimet` | `dict` | `{}` | Options for AIMET quantization. See [AIMET](#aimet)\] | +| `aimet` | `dict` | `{}` | Options for AIMET quantization. See [AIMET](#aimet) | ### `ONNX` From c3090127260e343de16528a62d2e41be98ca60f6 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 21 May 2026 18:46:08 +0200 Subject: [PATCH 71/74] moved outside of loop --- .../metrics/confusion_matrix/detection_confusion_matrix.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py index c102e921..35246dfc 100644 --- a/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py +++ b/luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py @@ -51,6 +51,9 @@ def compute(self) -> dict[str, Tensor]: } def _update(self, predictions: list[Tensor], targets: Tensor) -> None: + if self.confusion_matrix.is_inference(): + self.confusion_matrix = self.confusion_matrix.clone() + for pred, target in zip( predictions, instances_from_batch(targets, batch_size=len(predictions)), @@ -58,8 +61,6 @@ def _update(self, predictions: list[Tensor], targets: Tensor) -> None: ): pred_classes = pred[:, 5].int() target_classes = target[:, 0].int() - if self.confusion_matrix.is_inference(): - self.confusion_matrix = self.confusion_matrix.clone() if target.numel() == pred.numel() == 0: self.confusion_matrix[self.n_classes, self.n_classes] += 1 From a1fb63bfc64de1f8a82bbc2acf9dd32734334f7d Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 21 May 2026 18:46:14 +0200 Subject: [PATCH 72/74] loading weights before eval --- luxonis_train/core/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index 417fac1e..abeccc5a 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1428,11 +1428,12 @@ def quantize( model = self.lightning_module model.reparametrize().eval() - pre_quant_test = self.pl_trainer.test(model, self.val_loader)[0] if weights is not None: model.load_checkpoint(weights) + pre_quant_test = self.pl_trainer.test(model, self.val_loader)[0] + dummy_inputs = { input_name: torch.randn([1, *shape]).to(model.device) for shapes in model.nodes.loader_input_shapes.values() From 4dc9884ddb7a2817e56bc6fe227b76fadd49870e Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 21 May 2026 18:48:40 +0200 Subject: [PATCH 73/74] safer values --- luxonis_train/core/core.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py index abeccc5a..5fb4ea64 100644 --- a/luxonis_train/core/core.py +++ b/luxonis_train/core/core.py @@ -1460,10 +1460,16 @@ def quantize( QuantizationDataType[default_data_type or cfg.default_data_type], aimet_config_file, adaround, - adaround_iterations or cfg.adaround.default_num_iterations, - adaround_reg_param or cfg.adaround.default_reg_param, + adaround_iterations + if adaround_iterations is not None + else cfg.adaround.default_num_iterations, + adaround_reg_param + if adaround_reg_param is not None + else cfg.adaround.default_reg_param, adaround_beta_range or cfg.adaround.default_beta_range, - adaround_warm_start or cfg.adaround.default_warm_start, + adaround_warm_start + if adaround_warm_start is not None + else cfg.adaround.default_warm_start, fold_batch_norms, cross_layer_equalization, batch_norm_reestimation, @@ -1495,7 +1501,7 @@ def quantize( self.train_loader, optimizer, scheduler, - epochs or cfg.epochs, + epochs if epochs is not None else cfg.epochs, fold_batch_norms, batch_norm_reestimation, ).eval() From 70faf84608564e7303e7efe5a44be6d8006a09c0 Mon Sep 17 00:00:00 2001 From: Martin Kozlovsky Date: Thu, 21 May 2026 18:52:35 +0200 Subject: [PATCH 74/74] fix docs --- luxonis_train/__main__.py | 9 +++++++ luxonis_train/callbacks/README.md | 10 ++++++- .../callbacks/luxonis_progress_bar.py | 26 +++++++------------ luxonis_train/utils/dataset_metadata.py | 5 ++-- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py index abeb0fb9..271917bf 100644 --- a/luxonis_train/__main__.py +++ b/luxonis_train/__main__.py @@ -490,6 +490,15 @@ def quantize( config: str | None = None, weights: str | None = None, ): + """Quantize the model using AIMET. + + @type config: str + @param config: Path to the configuration file. + @type weights: str + @param weights: Path to the model weights. + @type opts: list[str] + @param opts: A list of optional CLI overrides of the config file. + """ model = create_model( config, opts, weights=weights, allow_empty_dataset=True ) diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md index 9b4c500e..c45d59b5 100644 --- a/luxonis_train/callbacks/README.md +++ b/luxonis_train/callbacks/README.md @@ -188,7 +188,15 @@ A callback that maintains an exponential moving average (EMA) of the model's par ## `AIMETCallback` -# Callback to perform AIMET quantization at the end of the training. +Callback to perform AIMET quantization at the end of the training. + +This callback runs AIMET post-training static quantization using the best checkpoint (by default based on the main metric) at the end of training. + +**Parameters:** + +| Key | Type | Default value | Description | +| ---------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `preferred_checkpoint` | `Literal["metric", "loss"]` | `"metric"` | Which checkpoint should the callback use. If the preferred checkpoint is not available, the other option is used. If none is available, the callback is skipped | ## `TrainingProgressCallback` diff --git a/luxonis_train/callbacks/luxonis_progress_bar.py b/luxonis_train/callbacks/luxonis_progress_bar.py index 1e917ba1..731e6e49 100644 --- a/luxonis_train/callbacks/luxonis_progress_bar.py +++ b/luxonis_train/callbacks/luxonis_progress_bar.py @@ -70,8 +70,8 @@ def print_table( @type title: str @param title: Title of the table - @type table: Mapping[str, int | str | float] - @param table: Table to print + @type table: Iterable[tuple[str | int | float, ...]] + @param table: Table to print as an iterable of rows, where each row is a tuple of values. @type column_names: list[str] @param column_names: Names of the columns in the table """ @@ -181,13 +181,10 @@ def print_table( @type title: str @param title: Title of the table - @type table: Mapping[str, int | str | float] - @param table: Table to print - @type key_name: str - @param key_name: Name of the key column. Defaults to C{"Name"}. - @type value_name: str - @param value_name: Name of the value column. Defaults to - C{"Value"}. + @type table: Iterable[tuple[str | int | float, ...]] + @param table: Table to print as an iterable of rows, where each row is a tuple of values. + @type column_names: list[str] + @param column_names: Names of the columns in the table """ self._rule(title) formatted = tabulate( @@ -332,13 +329,10 @@ def print_table( @type title: str @param title: Title of the table - @type table: Mapping[str, int | str | float] - @param table: Table to print - @type key_name: str - @param key_name: Name of the key column. Defaults to C{"Name"}. - @type value_name: str - @param value_name: Name of the value column. Defaults to - C{"Value"}. + @type table: Iterable[tuple[str | int | float, ...]] + @param table: Table to print as an iterable of rows, where each row is a tuple of values. + @type column_names: list[str] + @param column_names: Names of the columns in the table @param console: Console instance to use, if None use default console. Defaults to None. @type console: Console | None diff --git a/luxonis_train/utils/dataset_metadata.py b/luxonis_train/utils/dataset_metadata.py index 2b54eb82..7691e65b 100644 --- a/luxonis_train/utils/dataset_metadata.py +++ b/luxonis_train/utils/dataset_metadata.py @@ -122,9 +122,8 @@ def n_keypoints(self, task_name: str | None = None) -> int: @type task_name: str | None @param task_name: Task to get the number of keypoints for. @rtype: int - @return: Number of keypoints for the specified task type. - @raises ValueError: If the C{task} is not present in the - dataset. + @return: Number of keypoints for the specified task type or 0 if + the task does not involve keypoints. @raises RuntimeError: If the C{task} was not provided and the dataset contains different number of keypoints for different task types.