Skip to content

AIMET Integration#366

Open
kozlov721 wants to merge 78 commits into
mainfrom
feature/aimet
Open

AIMET Integration#366
kozlov721 wants to merge 78 commits into
mainfrom
feature/aimet

Conversation

@kozlov721
Copy link
Copy Markdown
Collaborator

@kozlov721 kozlov721 commented Apr 8, 2026

Purpose

Adds option to quantize trained model using various PTQ and QAT techniques.

Specification

  • Added LuxonisModel.quantize method
    • Runs selected algorithms on the model weights and exports it to ONNX
    • Support for various PQT techniques:
      • Adaround
      • Batch norm folding
      • Cross-layer equalization
      • Batch norm re-estimation
      • Sequential MSE
  • Added AIMETCallback
    • Runs quantization at the end of the training
  • Added a new section exporter.aimet to the Config
  • Other small changes necessary to accomodate the AIMET API or to simplify the integration
    • Added more standard forward method to LuxonixLightningModule
      • The original renamed to full_forward
    • Support for loaders outputting bare tensors
      • As opposed to the expected dictionary
    • Removed metrics, losses and visuazilers from saved node modules
    • Added print_table as a required abstract method to BaseLuxonisProgressBar
    • Implemented custom __getstate__ and __setstate__ in LuxonisLightningModule

config.yaml

# Default values
exporter:
  aimet:
    active: false

    default_output_bw: 8
    default_param_bw: 8
    default_data_type: "int"
    quant_scheme: "min_max"
    config: ~

    fold_batch_norms: false
    cross_layer_equalization: false
    batch_norm_reestimation: false
    sequential_mse: false
    adaround:
      active: false
      default_num_iterations: ~
      default_reg_param: 0.01
      default_beta_range: [20, 2]
      default_warm_start: 0.2

    epochs: 20
    optimizer:
      name: "SGD"
      params:
        lr: 0.00001
    scheduler:
      name: "StepLR"
      params:
        step_size: 5
        gamma: 0.1

Dependencies & Potential Impact

None / not applicable

Deployment Plan

None / not applicable

Testing & Validation

  • Testing AIMETCallback in test_callbacks
  • Testing LuxonisModel.quantize for all predefined models with the full set of PTQ techniques enabled

Summary by CodeRabbit

  • New Features

    • AIMET quantization added (PTQ & QAT) with a new quantize CLI command and model.quantize API; optional AIMET install available.
    • Model init accepts command-line weights.
  • Bug Fixes & Improvements

    • Loaders accept either a tensor or image-keyed dict for inputs.
    • More robust export/pickling and progress-bar table rendering.
    • Improved numeric stability in visualizers and metric/loader behaviors.
  • Documentation

    • Installation and CLI docs updated to document AIMET/quantize.

Review Change Stack

@kozlov721 kozlov721 force-pushed the feature/aimet branch 3 times, most recently from 9046d74 to 7a5d773 Compare April 14, 2026 03:40
@klemen1999 klemen1999 requested a review from dtronmans May 15, 2026 08:24
Copy link
Copy Markdown
Contributor

@dtronmans dtronmans left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few comments from reading it (did not test locally yet), could you also please include in the PR comment a few notes/justifications on the changes that are not necessarily related to the PR? This would be related to the changes to the losses, backbones, visualizers etc... Just to know what is in the scope of AIMET and what is not

Comment thread configs/README.md Outdated
Comment thread configs/README.md Outdated
**cfg.scheduler.params,
)

model = quantization_aware_training(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to disable QAT and only run PTQ in the config? Right now it seems as though QAT is unconditional

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can disable it by setting the number of epochs to 0.

Comment thread luxonis_train/config/config.py
Comment thread configs/README.md Outdated
| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights |
| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters |
| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values |
| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use |
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The quant_scheme values mentioned in the config are quant_scheme: Literal["min_max", "tf", "tf_enhanced"] = "min_max"

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"min_max" and "tf" is the same, "tf" is a deprecated name so I'm not mentioning it in the README but I made the config still accept it.

Comment thread luxonis_train/core/core.py
def validate_active(cls, data: Params) -> Params:
if not data.get("active", False):
return data
for required_field in [
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that for some parameters there is a mismatch between the validator, the config and the signature in the quantize method. Because fold_batch_norms in the config is by default False, but here it is checking whether or not it is set (even though it doesn't necessarily need to be set since there is a default value). And in the quantize() method the fold_batch_norms has default None in the signature

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional:

What if we just make this a required parameter is PTQ is active?

You mean the opt-in advanced techniques? Yeah I think this would make sense. What you want to use usually also depends on the specific model so as long as the entire AIMET is opt-in I think forcing the user to explicitly specify what to use is a good idea.

We decided that for the near future if you enable AIMET you have to explicitly enable or disable each technique you want to use.

Comment thread requirements.txt
Copy link
Copy Markdown
Collaborator Author

@kozlov721 kozlov721 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you also please include in the PR comment a few notes/justifications on the changes that are not necessarily related to the PR?

Mostly all the changes are somehow related to making AIMET work. I added more comments to some individual changes in code. If you have more questions about some other specific changes let me know.

self.b_cross_entropy = BCEWithLogitsLoss(
pos_weight=torch.tensor([viz_pw])
)
self.pos_weight = torch.tensor([viz_pw])
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keeping a reference to another BaseLoss that is not attached to a specific node is problematic when copying the model. In this case it's easier to use F.binary_cross_entropy_with_logits directly instead of our BSEWithLogitsLoss.

adj_kpts[..., 1] += y_adj
return adj_kpts

def _init_parameters(self, features: list[Tensor]) -> None:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parameters created this way were causing weird errors during quantization because the tensors were created during inference_mode.

(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.dtype == img1.dtype:
window = self.window.to(device)
window = self.window.to(device).clone()
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cloning is another way how to fix the "tensors created in inference mode" issue.

"""
return super().compute()

def __eq__(self, other: object) -> bool:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes an issue with fold_all_batch_norms.

torchmetrics.Metric supports chaining individual metrics into larger pipelines using overloaded math operators.

For example you can do:

from torchmetrics import Precision, Recall

precision = Precision(task="binary")
recall = Recall(task="binary")

# Operator overloading on the classes
f1_score = 2 * (precision * recall) / (precision + recall)

The f1_score works the same as the official torchmetrics.F1Score, but it was created by pipelining smaller metrics.

from torchmetrics import F1Score

official_f1_score = F1Score(task="binary")

f1_score.update(
    tensor([0, 1, 1, 0, 0, 1]), tensor([0, 1, 0, 1, 1, 1])
)
print(f"F1 Score: {f1_score.compute()}")
# F1 Score: 0.5714

official_f1_score.update(
    tensor([0, 1, 1, 0, 0, 1]), tensor([0, 1, 0, 1, 1, 1])
)
print(f"Official F1 Score: {official_f1_score.compute()}")
# Official F1 Score: 0.5714

print(f1_score)
# CompositionalMetric(
#   true_divide(
#     CompositionalMetric(
#       mul(
#         2,
#         CompositionalMetric(
#           mul(
#             BinaryPrecision(),
#             BinaryRecall()
#           )
#         )
#       )
#     ),
#     CompositionalMetric(
#       add(
#         BinaryPrecision(),
#         BinaryRecall()
#       )
#     )
#   )
# )

This is cool but it works for == and != as well and for combinations of torchmetrics.Metric and other types:

foo = precision != 2
print(foo)
# CompositionalMetric(
#   ne(
#     BinaryPrecision(),
#     2
#   )
# )

This has one major disadvantage:

foo = precision != 2
print(bool(foo))
# True

if precision == None:
    print("This should not happen")
# This should not happen

This breaks AIMET which expects comparisons to work, so in order to fix it we have to re-implement __eq__ (and __hash__) ourselves.

super().__init__(*args, **kwargs)
self.scale = scale

@override
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ColorMap internally uses a generator so it cannot be pickled. Before pickling we remove it from the instance state.


@override
def forward(
self, inputs: dict[str, Tensor] | Tensor
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AIMET expects a simple forward that expects a regular nn.Tensor.

"artifacts": sorted(artifact_keys),
}

@override
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trainer and core cannot be pickled which causes the quantization to fail. We could just delete them before pickling but we do need to actually remember their states. This hack makes pickling and unpickling work as long as there is only one instance of LuxonisLightning which is good enough for the quantization to succeed.

super().__init__()
self.name = name
self.module = module
self.losses = _to_module_dict(losses)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the attached modules registered as modules of the overall LuxonisLightning module caused issues during quantization. After this change the attached modules are no longer considered a part of the core model itself.

"`augment_test_image` method to expose this functionality."
)

def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AIMET expects dataloaders to have a regular __getitem__ that returns a tuple with a simple torch.Tensor and labels.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 21, 2026

Warning

Rate limit exceeded

@kozlov721 has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 41 minutes and 46 seconds before requesting another review.

You’ve run out of usage credits. Purchase more in the billing tab.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 13262b04-dfb1-4df4-9b20-e03992c2e3b7

📥 Commits

Reviewing files that changed from the base of the PR and between 25656ab and 70faf84.

📒 Files selected for processing (7)
  • configs/README.md
  • luxonis_train/__main__.py
  • luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py
  • luxonis_train/callbacks/README.md
  • luxonis_train/callbacks/luxonis_progress_bar.py
  • luxonis_train/core/core.py
  • luxonis_train/utils/dataset_metadata.py
📝 Walkthrough

Walkthrough

Adds AIMET quantization support (config, utilities, LuxonisModel.quantize, CLI, callback) and refactors loader/forward interfaces to accept dict or tensor inputs; updates Lightning export/pickling, loss/visualizer/metric modules, packaging, CI, docs, and tests.

Changes

AIMET Quantization Integration

Layer / File(s) Summary
Loader API Refactoring: get() → getitem()
luxonis_train/loaders/base_loader.py, luxonis_train/loaders/dummy_loader.py, luxonis_train/loaders/luxonis_loader_torch.py, luxonis_train/loaders/luxonis_perlin_loader_torch.py, tests/unittests/test_loaders/test_base_loader.py, tests/integration/test_custom_model.py, media/anomaly_detection_diagram.drawio
Base loader contract changed from optional get() to required __getitem__(). Type now permits both Tensor and dict[str, Tensor] inputs. All concrete loaders and test loaders updated; diagram and documentation aligned.
Input Format Flexibility: Dict & Tensor Support
luxonis_train/__main__.py, luxonis_train/core/utils/infer_utils.py, luxonis_train/core/utils/annotate_utils.py, luxonis_train/callbacks/gradcam_visualizer.py, luxonis_train/loaders/base_loader.py
Forward/inference paths now accept both dict[str, Tensor] and raw Tensor inputs. Normalizes images to loader.image_source dict key; handles both formats in inference/annotation utilities and callbacks. Collate function branches on input type; visualization item wrapping normalizes non-dict images.
Lightning Module Core: Forward, Training, Export
luxonis_train/lightning/luxonis_lightning.py
Refactors LuxonisLightningModule to expose tensor-only forward interface returning tuple[Tensor, ...]. Normalizes full_forward inputs/device placement. Adds fluent train()/set_export_mode()/reparametrize() methods returning Self. Improves ONNX export via new _get_output_onnx_names() helper and corrects device handling. Adds detach() for trainer cleanup.
Lightning Module: Pickling & State Management
luxonis_train/lightning/luxonis_lightning.py
Adds custom __getstate__/__setstate__ for trainer/core via _INTERNAL registry. Updates validation_step/test_step/predict_step batch typing to accept dict|Tensor with proper device/denormalization handling.
AIMET Configuration Models
luxonis_train/config/config.py
Introduces AdaroundConfig and AIMETConfig Pydantic models with defaults, field validation (enforces required fields when active, loads JSON config from file path), enum serialization. Extends ExportConfig with aimet field.
AIMET Quantization Utilities
luxonis_train/core/utils/aimet_utils.py
Implements check_aimet_available(), post_training_quantization() (model/input CUDA movement, batch-norm folding, cross-layer equalization, AdaRound, sequential MSE, calibration), and quantization_aware_training() (manual training loop with encoding management).
LuxonisModel.quantize() Method & Infrastructure
luxonis_train/core/core.py
Adds public quantize() method coordinating AIMET: resolves hyperparameters from config/args, builds dummy inputs, deep-copies model when needed, runs pre/PTQ/QAT validation, exports ONNX, prints results. Adds typed train_loader/val_loader/test_loader properties. Fixes test(new_thread=True) to return Thread instead of None.
CLI Quantize Command & Visualization Normalization
luxonis_train/__main__.py
Adds quantize CLI command (export_group sort_key=1) calling model.quantize(). Updates _yield_visualizations to normalize loader images to {loader.image_source: images} dict.
AIMETCallback for Post-Training Quantization
luxonis_train/callbacks/aimet_callback.py, luxonis_train/callbacks/__init__.py, luxonis_train/lightning/utils.py
Introduces AIMETCallback subclassing NeedsCheckpoint, executing pl_module.core.quantize() on on_train_end. Exported publicly and integrated when exporter.aimet.active=True.
Loss Module Buffer Registration
luxonis_train/attached_modules/losses/adaptive_detection_loss.py, luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py
Registers scale/rate buffers (gt_bboxes_scale, gt_kpts_scale) in __init__ via register_buffer(persistent=False) instead of lazy initialization. Updates _init_parameters guards and stride cloning for AIMET compatibility.
Visualizer Colormap Centralization
luxonis_train/attached_modules/visualizers/base_visualizer.py, luxonis_train/attached_modules/visualizers/embeddings_visualizer.py, luxonis_train/attached_modules/visualizers/segmentation_visualizer.py
Moves ColorMap from subclasses to BaseVisualizer as @cached_property. Adds __getstate__ to exclude colormap from pickling. Subclasses now inherit via self.colormap. Includes NaN/Inf handling and empty-array axis-limit guard.
BaseMetric Identity-Based Equality & Hashing
luxonis_train/attached_modules/metrics/base_metric.py
Adds __eq__ and __hash__ for identity-based comparison/hashing to support set/dict operations on metric instances.
DetectionConfusionMatrix Inference Mode Safety
luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py
Clones confusion_matrix before in-place operations when in inference mode to prevent mutation during validation.
Progress Bar Table Rendering Refactor
luxonis_train/callbacks/luxonis_progress_bar.py
Standardizes table rendering behind new abstract print_table() method accepting iterable of row tuples and column_names. Updates TQDM (tabulate-based) and Rich implementations with dynamic column formatting (5-decimal float precision for Rich).
Block Reparametrization Error Handling
luxonis_train/nodes/blocks/blocks.py
Changes GeneralReparametrizableBlock.reparametrize() and restore() to return early instead of raising RuntimeError, enabling idempotent calls.
Torch Hub Trust Flags
luxonis_train/nodes/backbones/dinov3/dinov3.py, luxonis_train/nodes/backbones/efficientnet.py
Adds trust_repo=True to torch.hub.load() calls for DINOv3 and EfficientNet backbone loading.
AIMET-Compatible Quantized Blocks
luxonis_train/nodes/backbones/pplcnet_v3/blocks.py, luxonis_train/nodes/blocks/__init__.py
Adds conditional QuantizedAffineBlock guarded by suppress(ImportError) when aimet_torch available. Integrates QuantizationMixin with __quant_init__ and forward override. Configures mixin to ignore DropPath/UpscaleOnline.
NodeWrapper Fluent Training & Flexibility
luxonis_train/lightning/utils.py
Adds NodeWrapper.train() propagating mode to node and attached modules, returning Self. Updates LossAccumulator.__init__ to accept *args/**kwargs. Removes _to_module_dict() helper.
Minor Improvements
luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py, luxonis_train/config/predefined_models/base_predefined_model.py, luxonis_train/utils/dataset_metadata.py, luxonis_train/registry.py, luxonis_train/utils/__init__.py
SSIM window cloning for device moves; main_metric selection guarded by metrics presence; DatasetMetadata error messages enhanced; _INTERNAL registry added; utility export deduplication.
AIMET Documentation & CLI
README.md, configs/README.md, luxonis_train/callbacks/README.md
Adds AIMET installation section with [aimet] extra and CUDA index URL. Documents quantize CLI command. Adds AIMET exporter options and detailed AIMET/Adaround configuration sections. Documents AIMETCallback.
Packaging & Dependencies
pyproject.toml, requirements-aimet.txt, requirements.txt, .github/workflows/ci.yaml
Adds aimet optional extra to pyproject.toml. Creates requirements-aimet.txt with aimet-torch~=2.31, torch==2.11, torchvision~=0.26. Updates CI workflows to install [dev,aimet] with CUDA 13.0 index. Removes torch<2.11 constraint.
Test Fixtures & Integration Tests
tests/conftest.py, tests/integration/test_predefined_models.py, tests/integration/test_callbacks.py, tests/integration/backbone_model_utils.py, tests/integration/test_combinations.py, tests/unittests/test_utils/test_dataset_metadata.py
Adds exporter.aimet config to opts fixture. Extends test_predefined_models with quantize subtest. Disables AIMET in test_callbacks (determinism conflict). Updates backbone/test utilities. Fixes metadata test assertions.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

🐰 Whisker-twitching through quantized dreams,
We bundle Affine blocks with AIMET schemes,
From dicts to tensors, the loaders now flex,
ONNX exports what Lightning's next,
With callbacks aplenty and buffers reborn.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 36.64% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly and concisely summarizes the main objective of this PR: adding AIMET integration for quantization support to the project.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feature/aimet

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
luxonis_train/callbacks/luxonis_progress_bar.py (2)

323-365: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update docstring to match the current signature.

The docstring at lines 335-341 references old parameters (key_name, value_name) that are no longer part of the method signature. Update the documentation to reflect the new table and column_names parameters.

📝 Proposed docstring fix
     `@override`
     def print_table(
         self,
         title: str,
         table: Iterable[tuple[str | int | float, ...]],
         column_names: list[str],
         console: Console | None = None,
     ) -> None:
         """Prints table to the console using rich text.
 
         `@type` title: str
         `@param` title: Title of the table
-        `@type` table: Mapping[str, int | str | float]
-        `@param` table: Table to print
-        `@type` key_name: str
-        `@param` key_name: Name of the key column. Defaults to C{"Name"}.
-        `@type` value_name: str
-        `@param` value_name: Name of the value column. Defaults to
-            C{"Value"}.
-        `@param` console: Console instance to use, if None use default
-            console. Defaults to None.
+        `@type` table: Iterable[tuple[str | int | float, ...]]
+        `@param` table: Iterable of row tuples, where each tuple contains
+            the values for one row
+        `@type` column_names: list[str]
+        `@param` column_names: Names of the columns in the table
         `@type` console: Console | None
+        `@param` console: Console instance to use, if None use default
+            console. Defaults to None.
         """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/callbacks/luxonis_progress_bar.py` around lines 323 - 365, The
docstring for print_table is out of date: remove references to old
key_name/value_name parameters and update descriptions to document the current
signature parameters (title: str, table: Iterable[tuple[str|int|float, ...]],
column_names: list[str], console: Console | None). In the print_table docstring
mention that table is an iterable of rows where the first element is the row
key/name and the rest are column values, and that column_names provides the
header labels; keep the console description and types in the docstring
consistent with the function signature and typing used in the method.

173-199: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update docstring to match the current signature.

The docstring at lines 184-190 references old parameters (key_name, value_name) that no longer exist in the method signature. Update the documentation to reflect the new table and column_names parameters.

📝 Proposed docstring fix
     `@override`
     def print_table(
         self,
         title: str,
         table: Iterable[tuple[str | int | float, ...]],
         column_names: list[str],
     ) -> None:
         """Prints table to the console using tabulate.
 
         `@type` title: str
         `@param` title: Title of the table
-        `@type` table: Mapping[str, int | str | float]
-        `@param` table: Table to print
-        `@type` key_name: str
-        `@param` key_name: Name of the key column. Defaults to C{"Name"}.
-        `@type` value_name: str
-        `@param` value_name: Name of the value column. Defaults to
-            C{"Value"}.
+        `@type` table: Iterable[tuple[str | int | float, ...]]
+        `@param` table: Iterable of row tuples, where each tuple contains
+            the values for one row
+        `@type` column_names: list[str]
+        `@param` column_names: Names of the columns in the table
         """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/callbacks/luxonis_progress_bar.py` around lines 173 - 199, The
print_table method's docstring is outdated: it mentions removed parameters
key_name and value_name; update the docstring on print_table to describe the
current signature (title: str, table: Iterable[tuple[str|int|float, ...]>,
column_names: list[str]), explain what table and column_names represent and any
formatting or return behavior, and remove references to key_name/value_name;
locate the method by name print_table in luxonis_progress_bar.py and adjust the
docstring text accordingly to match parameters and types used in the function.
luxonis_train/utils/dataset_metadata.py (1)

119-133: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Docstring is inconsistent with the new behavior.

The docstring at lines 126-127 states @raises ValueError: If the C{task} is not present in the dataset, but the implementation now returns 0 for unknown task names instead of raising.

📝 Proposed fix
     def n_keypoints(self, task_name: str | None = None) -> int:
         """Gets the number of keypoints for the specified task.
 
         `@type` task_name: str | None
         `@param` task_name: Task to get the number of keypoints for.
         `@rtype`: int
-        `@return`: Number of keypoints for the specified task type.
-        `@raises` ValueError: If the C{task} is not present in the
-            dataset.
+        `@return`: Number of keypoints for the specified task type,
+            or 0 if the task is not present in the dataset.
         `@raises` RuntimeError: If the C{task} was not provided and the
             dataset contains different number of keypoints for different
             task types.
         """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/utils/dataset_metadata.py` around lines 119 - 133, The
docstring promises a ValueError when a requested task is not present, but
n_keypoints currently returns 0 for unknown task names; update n_keypoints to
raise ValueError when task_name is provided but not found in self._n_keypoints
(use the method name n_keypoints and the attribute self._n_keypoints to locate
the code) so behavior matches the docstring, or alternatively update the
docstring to state that unknown tasks return 0—prefer changing the
implementation to raise ValueError to preserve the documented contract.
🧹 Nitpick comments (5)
luxonis_train/callbacks/aimet_callback.py (1)

13-16: 💤 Low value

Consider adding logging to indicate quantization progress.

The callback silently calls quantize() at training end. Adding a log message would help users understand what's happening, especially since AIMET quantization can be time-consuming (per PR discussion, Adaround on a toy COCO model took ~40 minutes).

💡 Suggested improvement
+from loguru import logger
+
 `@CALLBACKS.register`()
 class AIMETCallback(NeedsCheckpoint):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
     def on_train_end(
         self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule"
     ) -> None:
+        logger.info("Starting AIMET quantization...")
         pl_module.core.quantize(self.get_checkpoint(pl_module))
+        logger.info("AIMET quantization complete.")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/callbacks/aimet_callback.py` around lines 13 - 16, Add
informational logging around the AIMET quantization call so users know when
quantization starts, is in progress, and finishes (or fails): in the callback
method on_train_end, before calling
pl_module.core.quantize(self.get_checkpoint(pl_module)) emit a log/info message
that quantization is starting (include model/checkpoint identification via
get_checkpoint(pl_module)), optionally log progress checkpoints if available
from LuxonisLightningModule.core.quantize, and after the call emit a completion
log or catch exceptions to log an error; reference on_train_end, get_checkpoint,
and the LuxonisLightningModule.core.quantize call when making the changes.
luxonis_train/lightning/utils.py (1)

56-58: 💤 Low value

Unused *args, **kwargs in LossAccumulator.__init__.

The signature accepts *args, **kwargs but they are silently discarded—only float is passed to the parent defaultdict. If this was added for pickle/deepcopy compatibility, consider documenting that intent. If callers might mistakenly pass arguments expecting them to be used, this could mask errors.

💡 Suggested documentation
 class LossAccumulator(defaultdict[str, float]):
     def __init__(self, *args, **kwargs):
+        # args/kwargs accepted for pickle/deepcopy compatibility but not used;
+        # factory is always `float`.
         super().__init__(float)
         self.counts = defaultdict(int)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/lightning/utils.py` around lines 56 - 58, The __init__ of
LossAccumulator accepts unused *args and **kwargs which are silently discarded;
change the signature to def __init__(self): (remove *args, **kwargs) and keep
the body calling super().__init__(float) and self.counts = defaultdict(int);
this removes the misleading API surface (or alternatively, if those params were
intentionally for pickling, explicitly document that intent and forward them to
super via super().__init__(float, *args, **kwargs)). Ensure you update
LossAccumulator.__init__, the call to super().__init__(float), and the
initialization of counts accordingly.
luxonis_train/callbacks/README.md (1)

189-192: ⚡ Quick win

Expand AIMETCallback documentation to match other callbacks.

The AIMETCallback section is significantly less detailed than other callback documentation in this file. Consider adding:

  • A parameters table (or note that it has no direct parameters)
  • A reference to the exporter.aimet configuration section where users control quantization behavior
  • Default behavior note (e.g., "Added automatically when exporter.aimet.active=True")
  • Brief explanation of what quantization techniques are applied

Compare this to TestOnTrainEnd (lines 146-157) or ConvertOnTrainEnd (lines 113-133) for structure examples.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/callbacks/README.md` around lines 189 - 192, Update the
AIMETCallback documentation block to match the structure of other callbacks: add
a brief parameters table (or explicitly state "no direct parameters"), add a
line referencing the exporter.aimet configuration section (e.g., "Controlled via
exporter.aimet.* settings"), state the default behavior (for example
"Automatically added when exporter.aimet.active=True"), and include a short
summary of the quantization techniques applied (e.g., post-training
quantization, per-channel bias/scale folding, and calibration method used).
Ensure you reference the AIMETCallback symbol and mirror the style used in
TestOnTrainEnd/ConvertOnTrainEnd sections for consistency.
luxonis_train/callbacks/luxonis_progress_bar.py (1)

62-78: ⚡ Quick win

Update docstring to reflect the new tuple-based signature.

The docstring at line 73 doesn't clearly describe the new structure. Consider updating it to explain that table is an iterable of tuples where each tuple represents a row (e.g., (name, value1, value2, ...)), and column_names provides the header labels.

📝 Proposed docstring improvement
     `@abstractmethod`
     def print_table(
         self,
         title: str,
         table: Iterable[tuple[str | int | float, ...]],
         column_names: list[str],
     ) -> None:
         """Prints table to the console.
 
         `@type` title: str
         `@param` title: Title of the table
-        `@type` table: Mapping[str, int | str | float]
-        `@param` table: Table to print
+        `@type` table: Iterable[tuple[str | int | float, ...]]
+        `@param` table: Iterable of row tuples, where each tuple contains
+            the values for one row (e.g., (name, value1, value2, ...))
         `@type` column_names: list[str]
         `@param` column_names: Names of the columns in the table
         """
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/callbacks/luxonis_progress_bar.py` around lines 62 - 78, Update
the print_table docstring in the abstractmethod print_table to describe the new
tuple-based signature: state that table is an Iterable of tuples where each
tuple is a row (e.g., ("row_label", value1, value2, ...)) and that column_names
is a list of header labels mapping to tuple positions; include a short example
and clarify accepted element types (str | int | float) and that tuple lengths
should match len(column_names). Ensure the description references the
print_table method and its parameters (title, table, column_names).
luxonis_train/__main__.py (1)

485-497: ⚡ Quick win

Add docstring for CLI help consistency.

The quantize command lacks a docstring, unlike all other commands (train, tune, test, export, etc.). This means users won't see help text when running luxonis_train quantize --help.

📝 Proposed fix
 `@app.command`(group=export_group, sort_key=1)
 def quantize(
     opts: list[str] | None = None,
     /,
     *,
     config: str | None = None,
     weights: str | None = None,
 ):
+    """Quantize the model using AIMET.
+
+    `@type` config: str
+    `@param` config: Path to the configuration file or a name of a
+        predefined model.
+    `@type` weights: str
+    `@param` weights: Path to the model weights.
+    `@type` opts: list[str]
+    `@param` opts: A list of optional CLI overrides of the config file.
+    """
     model = create_model(
         config, opts, weights=weights, allow_empty_dataset=True
     )
     model.quantize()
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@luxonis_train/__main__.py` around lines 485 - 497, The quantize CLI command
is missing a docstring so it doesn't show help text; add a short descriptive
docstring immediately under the quantize function definition (the function named
quantize that calls create_model(...) and model.quantize()) describing what the
command does and its parameters (opts, config, weights) and any important
behavior (e.g., allow_empty_dataset=True), so the CLI help (luxonis_train
quantize --help) displays meaningful information.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@configs/README.md`:
- Line 513: Fix the malformed markdown link in the table row for `aimet`: locate
the cell containing the link text `](`#aimet`)\]` (adjacent to the `aimet` table
entry) and remove the extraneous `\]` so the link becomes `](`#aimet`)`,
preserving the rest of the row content.

In
`@luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py`:
- Around line 61-62: The per-image cloning of self.confusion_matrix inside the
loop is redundant and ineffective: move the clone logic out of the per-image
loop in the _update method so you clone at most once per _update call (and do it
outside any torch.inference_mode() context if you need a non-inference tensor).
Concretely, in the _update implementation check
self.confusion_matrix.is_inference() once before iterating over images and, if
needed, replace self.confusion_matrix = self.confusion_matrix.clone() there (or
perform the clone after leaving inference mode) so you avoid repeated
allocations and preserve intended mutation semantics.

In `@luxonis_train/core/core.py`:
- Around line 1461-1465: The call site that builds the adaround arguments
(passing adaround_reg_param, adaround_warm_start, adaround_iterations,
adaround_beta_range, etc.) and the place selecting epochs uses truthy "or"
fallbacks which incorrectly override valid numeric zeros; change each fallback
from the pattern "value or cfg.foo" to an explicit None check such as "value if
value is not None else cfg.foo" (apply this to adaround_reg_param,
adaround_warm_start, adaround_iterations, adaround_beta_range and the epochs
selection) so that 0.0 or 0 are preserved but None still falls back to the
config defaults.
- Around line 1430-1434: The baseline pre-quantization test is being run before
custom weights are applied; move the weights-loading so
model.load_checkpoint(weights) executes before computing pre_quant_test: ensure
you call model.load_checkpoint(weights) (when weights is not None) prior to
invoking model.reparametrize().eval() and self.pl_trainer.test(...) so
pre_quant_test reflects the provided weights.

In `@luxonis_train/core/utils/aimet_utils.py`:
- Around line 58-67: pass_calibration_data currently runs batches from
val_loader on the CPU, causing a device mismatch when the model is on CUDA;
update the function so it determines the model device (e.g., device =
next(model.parameters()).device), moves imgs to that device before invoking
model.forward (imgs = imgs.to(device)), and run the loop under torch.no_grad()
to avoid building gradients.
- Around line 159-169: The QAT training loop fails to move batch tensors to the
device before forwarding; in the loop that iterates over train_loader and calls
model.training_step((imgs, labels)) you must transfer imgs and labels to the
same device as the model (e.g., device = next(model.parameters()).device) before
computing loss/backward/optimizer.step; update the loop in the block using
track(...) so inputs and targets are moved to device (and ensure any
dtype/device-specific handling for labels is preserved) prior to calling
model.training_step, optimizer.zero_grad, loss.backward, and optimizer.step.
- Around line 86-107: Guard against val_loader.batch_size being None before
using it to compute num_batches: when building the AdaroundParameters
(AdaroundParameters(...) in this block) compute num_batches using a safe
expression that checks val_loader.batch_size and falls back to using
len(val_loader) (or a sensible default batch size) so math.ceil(2000 /
val_loader.batch_size) is only executed when batch_size is not None; keep
default_num_iterations set to adaround_iterations (no forced fallback) and pass
the safe ada_params into Adaround.apply_adaround as before.

---

Outside diff comments:
In `@luxonis_train/callbacks/luxonis_progress_bar.py`:
- Around line 323-365: The docstring for print_table is out of date: remove
references to old key_name/value_name parameters and update descriptions to
document the current signature parameters (title: str, table:
Iterable[tuple[str|int|float, ...]], column_names: list[str], console: Console |
None). In the print_table docstring mention that table is an iterable of rows
where the first element is the row key/name and the rest are column values, and
that column_names provides the header labels; keep the console description and
types in the docstring consistent with the function signature and typing used in
the method.
- Around line 173-199: The print_table method's docstring is outdated: it
mentions removed parameters key_name and value_name; update the docstring on
print_table to describe the current signature (title: str, table:
Iterable[tuple[str|int|float, ...]>, column_names: list[str]), explain what
table and column_names represent and any formatting or return behavior, and
remove references to key_name/value_name; locate the method by name print_table
in luxonis_progress_bar.py and adjust the docstring text accordingly to match
parameters and types used in the function.

In `@luxonis_train/utils/dataset_metadata.py`:
- Around line 119-133: The docstring promises a ValueError when a requested task
is not present, but n_keypoints currently returns 0 for unknown task names;
update n_keypoints to raise ValueError when task_name is provided but not found
in self._n_keypoints (use the method name n_keypoints and the attribute
self._n_keypoints to locate the code) so behavior matches the docstring, or
alternatively update the docstring to state that unknown tasks return 0—prefer
changing the implementation to raise ValueError to preserve the documented
contract.

---

Nitpick comments:
In `@luxonis_train/__main__.py`:
- Around line 485-497: The quantize CLI command is missing a docstring so it
doesn't show help text; add a short descriptive docstring immediately under the
quantize function definition (the function named quantize that calls
create_model(...) and model.quantize()) describing what the command does and its
parameters (opts, config, weights) and any important behavior (e.g.,
allow_empty_dataset=True), so the CLI help (luxonis_train quantize --help)
displays meaningful information.

In `@luxonis_train/callbacks/aimet_callback.py`:
- Around line 13-16: Add informational logging around the AIMET quantization
call so users know when quantization starts, is in progress, and finishes (or
fails): in the callback method on_train_end, before calling
pl_module.core.quantize(self.get_checkpoint(pl_module)) emit a log/info message
that quantization is starting (include model/checkpoint identification via
get_checkpoint(pl_module)), optionally log progress checkpoints if available
from LuxonisLightningModule.core.quantize, and after the call emit a completion
log or catch exceptions to log an error; reference on_train_end, get_checkpoint,
and the LuxonisLightningModule.core.quantize call when making the changes.

In `@luxonis_train/callbacks/luxonis_progress_bar.py`:
- Around line 62-78: Update the print_table docstring in the abstractmethod
print_table to describe the new tuple-based signature: state that table is an
Iterable of tuples where each tuple is a row (e.g., ("row_label", value1,
value2, ...)) and that column_names is a list of header labels mapping to tuple
positions; include a short example and clarify accepted element types (str | int
| float) and that tuple lengths should match len(column_names). Ensure the
description references the print_table method and its parameters (title, table,
column_names).

In `@luxonis_train/callbacks/README.md`:
- Around line 189-192: Update the AIMETCallback documentation block to match the
structure of other callbacks: add a brief parameters table (or explicitly state
"no direct parameters"), add a line referencing the exporter.aimet configuration
section (e.g., "Controlled via exporter.aimet.* settings"), state the default
behavior (for example "Automatically added when exporter.aimet.active=True"),
and include a short summary of the quantization techniques applied (e.g.,
post-training quantization, per-channel bias/scale folding, and calibration
method used). Ensure you reference the AIMETCallback symbol and mirror the style
used in TestOnTrainEnd/ConvertOnTrainEnd sections for consistency.

In `@luxonis_train/lightning/utils.py`:
- Around line 56-58: The __init__ of LossAccumulator accepts unused *args and
**kwargs which are silently discarded; change the signature to def
__init__(self): (remove *args, **kwargs) and keep the body calling
super().__init__(float) and self.counts = defaultdict(int); this removes the
misleading API surface (or alternatively, if those params were intentionally for
pickling, explicitly document that intent and forward them to super via
super().__init__(float, *args, **kwargs)). Ensure you update
LossAccumulator.__init__, the call to super().__init__(float), and the
initialization of counts accordingly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: eacfc6d3-7d1d-4fc3-afdf-f843d8e3f039

📥 Commits

Reviewing files that changed from the base of the PR and between 720d4a4 and bb7f28c.

📒 Files selected for processing (49)
  • .github/workflows/ci.yaml
  • README.md
  • configs/README.md
  • luxonis_train/__main__.py
  • luxonis_train/attached_modules/losses/adaptive_detection_loss.py
  • luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py
  • luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py
  • luxonis_train/attached_modules/metrics/base_metric.py
  • luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py
  • luxonis_train/attached_modules/visualizers/base_visualizer.py
  • luxonis_train/attached_modules/visualizers/embeddings_visualizer.py
  • luxonis_train/attached_modules/visualizers/segmentation_visualizer.py
  • luxonis_train/callbacks/README.md
  • luxonis_train/callbacks/__init__.py
  • luxonis_train/callbacks/aimet_callback.py
  • luxonis_train/callbacks/gradcam_visualizer.py
  • luxonis_train/callbacks/luxonis_progress_bar.py
  • luxonis_train/config/config.py
  • luxonis_train/config/predefined_models/base_predefined_model.py
  • luxonis_train/core/core.py
  • luxonis_train/core/utils/aimet_utils.py
  • luxonis_train/core/utils/annotate_utils.py
  • luxonis_train/core/utils/infer_utils.py
  • luxonis_train/lightning/luxonis_lightning.py
  • luxonis_train/lightning/utils.py
  • luxonis_train/loaders/base_loader.py
  • luxonis_train/loaders/dummy_loader.py
  • luxonis_train/loaders/luxonis_loader_torch.py
  • luxonis_train/loaders/luxonis_perlin_loader_torch.py
  • luxonis_train/nodes/backbones/dinov3/dinov3.py
  • luxonis_train/nodes/backbones/efficientnet.py
  • luxonis_train/nodes/backbones/pplcnet_v3/blocks.py
  • luxonis_train/nodes/blocks/__init__.py
  • luxonis_train/nodes/blocks/blocks.py
  • luxonis_train/registry.py
  • luxonis_train/utils/__init__.py
  • luxonis_train/utils/dataset_metadata.py
  • media/anomaly_detection_diagram.drawio
  • pyproject.toml
  • requirements-aimet.txt
  • requirements.txt
  • tests/conftest.py
  • tests/integration/backbone_model_utils.py
  • tests/integration/test_callbacks.py
  • tests/integration/test_combinations.py
  • tests/integration/test_custom_model.py
  • tests/integration/test_predefined_models.py
  • tests/unittests/test_loaders/test_base_loader.py
  • tests/unittests/test_utils/test_dataset_metadata.py
💤 Files with no reviewable changes (3)
  • luxonis_train/utils/init.py
  • requirements.txt
  • luxonis_train/attached_modules/visualizers/segmentation_visualizer.py

Comment thread configs/README.md Outdated
Comment thread luxonis_train/core/core.py
Comment thread luxonis_train/core/core.py Outdated
Comment thread luxonis_train/core/utils/aimet_utils.py
Comment thread luxonis_train/core/utils/aimet_utils.py
Comment thread luxonis_train/core/utils/aimet_utils.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLI Changes affecting the CLI documentation Improvements or additions to documentation enhancement New feature or request tests Adding or changing tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants