AIMET Integration #366
9046d74 to 7a5d773
7a5d773 to 3fc8cb4
dtronmans left a comment
Just a few comments from reading it (did not test locally yet), could you also please include in the PR comment a few notes/justifications on the changes that are not necessarily related to the PR? This would be related to the changes to the losses, backbones, visualizers etc... Just to know what is in the scope of AIMET and what is not
    **cfg.scheduler.params,
)

model = quantization_aware_training(
Is there a way to disable QAT and only run PTQ in the config? Right now it seems as though QAT is unconditional
You can disable it by setting the number of epochs to 0.
| `default_output_bw` | `int` | `8` | Default bitwidth for quantized activations and weights |
| `default_param_bw` | `int` | `8` | Default bitwidth for quantized parameters |
| `default_data_type` | `Literal["int", "float"]` | `int` | Default data type for quantized values |
| `quant_scheme` | `Literal["min_max", "post_training_tf_enhanced"]` | `min_max` | Quantization scheme to use |
The quant_scheme values mentioned in the config are quant_scheme: Literal["min_max", "tf", "tf_enhanced"] = "min_max"
"min_max" and "tf" are the same; "tf" is a deprecated name, so I don't mention it in the README, but the config still accepts it.
def validate_active(cls, data: Params) -> Params:
    if not data.get("active", False):
        return data
    for required_field in [
I think that for some parameters there is a mismatch between the validator, the config and the signature in the quantize method. Because fold_batch_norms in the config is by default False, but here it is checking whether or not it is set (even though it doesn't necessarily need to be set since there is a default value). And in the quantize() method the fold_batch_norms has default None in the signature
This is intentional:
What if we just make this a required parameter if PTQ is active?
You mean the opt-in advanced techniques? Yeah, I think this would make sense. What you want to use usually also depends on the specific model, so as long as the entire AIMET integration is opt-in, I think forcing the user to explicitly specify what to use is a good idea.
We decided that, for the near future, if you enable AIMET you have to explicitly enable or disable each technique.
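The "explicitly enable or disable each technique" rule could look roughly like the sketch below. It is a plain-function approximation, not the validator from the PR: only `fold_batch_norms` is named in this thread, and the other field names are hypothetical placeholders.

```python
from typing import Any

# Only `fold_batch_norms` comes from the discussion; the remaining names
# are hypothetical placeholders for other opt-in techniques.
REQUIRED_WHEN_ACTIVE = ("fold_batch_norms", "technique_b", "technique_c")


def validate_active(data: dict[str, Any]) -> dict[str, Any]:
    """Require an explicit choice for every technique once AIMET is active."""
    if not data.get("active", False):
        return data  # AIMET is off, so nothing has to be specified

    # A field counts as "set" when the key is present, even if it is False.
    missing = [name for name in REQUIRED_WHEN_ACTIVE if name not in data]
    if missing:
        raise ValueError(
            "AIMET is active, so these fields must be set explicitly: "
            + ", ".join(missing)
        )
    return data
```

Checking for key presence rather than truthiness is what lets a user write `fold_batch_norms: false` and still pass validation.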
kozlov721 left a comment
could you also please include in the PR comment a few notes/justifications on the changes that are not necessarily related to the PR?
Almost all of the changes are related in some way to making AIMET work. I added comments on some of the individual changes in the code. If you have questions about other specific changes, let me know.
self.b_cross_entropy = BCEWithLogitsLoss(
    pos_weight=torch.tensor([viz_pw])
)
self.pos_weight = torch.tensor([viz_pw])
Keeping a reference to another BaseLoss that is not attached to a specific node is problematic when copying the model. In this case it's easier to use F.binary_cross_entropy_with_logits directly instead of our BCEWithLogitsLoss.
adj_kpts[..., 1] += y_adj
return adj_kpts

def _init_parameters(self, features: list[Tensor]) -> None:
Parameters created this way were causing weird errors during quantization because the tensors were created during inference_mode.
(_, channel, _, _) = img1.size()
if channel == self.channel and self.window.dtype == img1.dtype:
    window = self.window.to(device)
    window = self.window.to(device).clone()
Cloning is another way to fix the "tensors created in inference mode" issue.
"""
return super().compute()

def __eq__(self, other: object) -> bool:
This fixes an issue with fold_all_batch_norms.
torchmetrics.Metric supports chaining individual metrics into larger pipelines using overloaded math operators.
For example you can do:
from torch import tensor
from torchmetrics import Precision, Recall
precision = Precision(task="binary")
recall = Recall(task="binary")
# Operator overloading on the metric objects
f1_score = 2 * (precision * recall) / (precision + recall)
The f1_score works the same as the official torchmetrics.F1Score, but it was created by pipelining smaller metrics.
from torchmetrics import F1Score
official_f1_score = F1Score(task="binary")
f1_score.update(
tensor([0, 1, 1, 0, 0, 1]), tensor([0, 1, 0, 1, 1, 1])
)
print(f"F1 Score: {f1_score.compute()}")
# F1 Score: 0.5714
official_f1_score.update(
tensor([0, 1, 1, 0, 0, 1]), tensor([0, 1, 0, 1, 1, 1])
)
print(f"Official F1 Score: {official_f1_score.compute()}")
# Official F1 Score: 0.5714
print(f1_score)
# CompositionalMetric(
# true_divide(
# CompositionalMetric(
# mul(
# 2,
# CompositionalMetric(
# mul(
# BinaryPrecision(),
# BinaryRecall()
# )
# )
# )
# ),
# CompositionalMetric(
# add(
# BinaryPrecision(),
# BinaryRecall()
# )
# )
# )
# )
This is convenient, but the overloading also applies to == and !=, including comparisons between a torchmetrics.Metric and other types:
foo = precision != 2
print(foo)
# CompositionalMetric(
# ne(
# BinaryPrecision(),
# 2
# )
# )
This has one major disadvantage:
foo = precision != 2
print(bool(foo))
# True
if precision == None:
print("This should not happen")
# This should not happen
This breaks AIMET, which expects comparisons to behave normally, so to fix it we have to re-implement __eq__ (and __hash__) ourselves.
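The effect can be reproduced without torchmetrics. The sketch below uses hypothetical stand-in classes (`ComposedExpression`, `ChainableMetric`, `PlainComparisonMetric`) to mimic a metric whose comparison operators build expressions, then restores identity-based comparison. It illustrates the idea only and is not the `__eq__` implementation from this PR.

```python
class ComposedExpression:
    """Hypothetical stand-in for a lazily evaluated metric expression."""

    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right


class ChainableMetric:
    """Mimics a metric whose == and != return an expression object."""

    def __eq__(self, other):
        return ComposedExpression("eq", self, other)

    def __ne__(self, other):
        return ComposedExpression("ne", self, other)

    # Defining __eq__ would otherwise disable hashing for this class.
    __hash__ = object.__hash__


class PlainComparisonMetric(ChainableMetric):
    """Restores ordinary identity-based comparison and hashing."""

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        return self is not other

    def __hash__(self):
        return id(self)


legacy = ChainableMetric()
fixed = PlainComparisonMetric()

# The expression object is always truthy, so membership-style lookups
# match the wrong element: None "equals" the metric.
print([None, legacy].index(legacy))  # 0
print([None, fixed].index(fixed))    # 1
```

Any library code that searches containers or compares modules with `==` hits the first behavior, which is one plausible way such a metric can confuse graph-manipulation tooling.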
super().__init__(*args, **kwargs)
self.scale = scale

@override
ColorMap internally uses a generator so it cannot be pickled. Before pickling we remove it from the instance state.
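A minimal sketch of that pattern, assuming a holder object that owns a generator-backed palette. `GeneratorPalette` and `SketchVisualizer` are hypothetical names; the real ColorMap and visualizer classes may differ.

```python
import pickle


class GeneratorPalette:
    """Hypothetical palette that hands out colors from a live generator."""

    def __init__(self):
        self._colors = self._cycle()

    @staticmethod
    def _cycle():
        palette = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
        while True:
            yield from palette

    def next_color(self):
        return next(self._colors)


class SketchVisualizer:
    """Holder that drops the unpicklable palette from its pickled state."""

    def __init__(self, scale):
        self.scale = scale
        self.palette = GeneratorPalette()

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["palette"]  # generators cannot be pickled
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Recreate the palette; its previous position is not preserved.
        self.palette = GeneratorPalette()


restored = pickle.loads(pickle.dumps(SketchVisualizer(scale=2.0)))
print(restored.scale)  # 2.0
```

The trade-off in this sketch is that the restored object starts with a fresh palette, so colors assigned before pickling are not guaranteed to repeat in the same order.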
@override
def forward(
    self, inputs: dict[str, Tensor] | Tensor
AIMET expects a simple forward that takes a regular torch.Tensor.
    "artifacts": sorted(artifact_keys),
}

@override
trainer and core cannot be pickled, which causes the quantization to fail. We could just delete them before pickling, but we do need to remember their states. This hack makes pickling and unpickling work as long as there is only one instance of LuxonisLightning, which is good enough for the quantization to succeed.
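One way such a single-instance workaround can be built is to park the unpicklable references on the class during `__getstate__` and reattach them in `__setstate__`. The sketch below is a guess at the general shape, not the code from this PR; `SketchLightningModule` is a hypothetical name, and `threading.Lock` objects stand in for the real trainer and core.

```python
import pickle
import threading


class SketchLightningModule:
    """Hypothetical module holding two references that cannot be pickled."""

    # Class-level storage: shared by all instances, hence the
    # "only one instance" limitation.
    _stash = {}

    def __init__(self, trainer, core, name):
        self.trainer = trainer
        self.core = core
        self.name = name

    def __getstate__(self):
        state = self.__dict__.copy()
        # Remove the unpicklable objects from the state and remember them.
        type(self)._stash = {
            "trainer": state.pop("trainer"),
            "core": state.pop("core"),
        }
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Reattach the very same objects that were stashed earlier.
        self.__dict__.update(type(self)._stash)


original = SketchLightningModule(threading.Lock(), threading.Lock(), "model")
clone = pickle.loads(pickle.dumps(original))
print(clone.trainer is original.trainer)  # True
```

With two live instances the second `__getstate__` call would overwrite the stash, which matches the single-instance caveat described above.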
super().__init__()
self.name = name
self.module = module
self.losses = _to_module_dict(losses)
Having the attached modules registered as modules of the overall LuxonisLightning module caused issues during quantization. After this change the attached modules are no longer considered a part of the core model itself.
    "`augment_test_image` method to expose this functionality."
)

def __getitem__(self, idx: int) -> LuxonisLoaderTorchOutput:
AIMET expects dataloaders to have a regular __getitem__ that returns a tuple with a simple torch.Tensor and labels.
📝 Walkthrough
Adds AIMET quantization support (config, utilities, LuxonisModel.quantize, CLI, callback) and refactors loader/forward interfaces to accept dict or tensor inputs; updates Lightning export/pickling, loss/visualizer/metric modules, packaging, CI, docs, and tests.
Changes: AIMET Quantization Integration
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
luxonis_train/callbacks/luxonis_progress_bar.py (2)
323-365: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update docstring to match the current signature.
The docstring at lines 335-341 references old parameters (key_name, value_name) that are no longer part of the method signature. Update the documentation to reflect the new table and column_names parameters.
📝 Proposed docstring fix
 @override
 def print_table(
     self,
     title: str,
     table: Iterable[tuple[str | int | float, ...]],
     column_names: list[str],
     console: Console | None = None,
 ) -> None:
     """Prints table to the console using rich text.

     @type title: str
     @param title: Title of the table
-    @type table: Mapping[str, int | str | float]
-    @param table: Table to print
-    @type key_name: str
-    @param key_name: Name of the key column. Defaults to C{"Name"}.
-    @type value_name: str
-    @param value_name: Name of the value column. Defaults to
-        C{"Value"}.
-    @param console: Console instance to use, if None use default
-        console. Defaults to None.
+    @type table: Iterable[tuple[str | int | float, ...]]
+    @param table: Iterable of row tuples, where each tuple contains
+        the values for one row
+    @type column_names: list[str]
+    @param column_names: Names of the columns in the table
     @type console: Console | None
+    @param console: Console instance to use, if None use default
+        console. Defaults to None.
     """
173-199: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Update docstring to match the current signature.
The docstring at lines 184-190 references old parameters (key_name, value_name) that no longer exist in the method signature. Update the documentation to reflect the new table and column_names parameters.
📝 Proposed docstring fix
 @override
 def print_table(
     self,
     title: str,
     table: Iterable[tuple[str | int | float, ...]],
     column_names: list[str],
 ) -> None:
     """Prints table to the console using tabulate.

     @type title: str
     @param title: Title of the table
-    @type table: Mapping[str, int | str | float]
-    @param table: Table to print
-    @type key_name: str
-    @param key_name: Name of the key column. Defaults to C{"Name"}.
-    @type value_name: str
-    @param value_name: Name of the value column. Defaults to
-        C{"Value"}.
+    @type table: Iterable[tuple[str | int | float, ...]]
+    @param table: Iterable of row tuples, where each tuple contains
+        the values for one row
+    @type column_names: list[str]
+    @param column_names: Names of the columns in the table
     """
luxonis_train/utils/dataset_metadata.py (1)
119-133: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Docstring is inconsistent with the new behavior.
The docstring at lines 126-127 states "@raises ValueError: If the C{task} is not present in the dataset", but the implementation now returns 0 for unknown task names instead of raising.
📝 Proposed fix
 def n_keypoints(self, task_name: str | None = None) -> int:
     """Gets the number of keypoints for the specified task.

     @type task_name: str | None
     @param task_name: Task to get the number of keypoints for.
     @rtype: int
-    @return: Number of keypoints for the specified task type.
-    @raises ValueError: If the C{task} is not present in the
-        dataset.
+    @return: Number of keypoints for the specified task type,
+        or 0 if the task is not present in the dataset.
     @raises RuntimeError: If the C{task} was not provided and the
         dataset contains different number of keypoints for different
         task types.
     """
🧹 Nitpick comments (5)
luxonis_train/callbacks/aimet_callback.py (1)
13-16: 💤 Low value
Consider adding logging to indicate quantization progress.
The callback silently calls quantize() at training end. Adding a log message would help users understand what's happening, especially since AIMET quantization can be time-consuming (per PR discussion, Adaround on a toy COCO model took ~40 minutes).
💡 Suggested improvement
+from loguru import logger
+
 @CALLBACKS.register()
 class AIMETCallback(NeedsCheckpoint):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

     def on_train_end(
         self, _: pl.Trainer, pl_module: "lxt.LuxonisLightningModule"
     ) -> None:
+        logger.info("Starting AIMET quantization...")
         pl_module.core.quantize(self.get_checkpoint(pl_module))
+        logger.info("AIMET quantization complete.")
luxonis_train/lightning/utils.py (1)
56-58: 💤 Low value
Unused *args, **kwargs in LossAccumulator.__init__.
The signature accepts *args, **kwargs but they are silently discarded; only float is passed to the parent defaultdict. If this was added for pickle/deepcopy compatibility, consider documenting that intent. If callers might mistakenly pass arguments expecting them to be used, this could mask errors.
💡 Suggested documentation
 class LossAccumulator(defaultdict[str, float]):
     def __init__(self, *args, **kwargs):
+        # args/kwargs accepted for pickle/deepcopy compatibility but not used;
+        # factory is always `float`.
         super().__init__(float)
         self.counts = defaultdict(int)
luxonis_train/callbacks/README.md (1)
189-192: ⚡ Quick win
Expand AIMETCallback documentation to match other callbacks.
The AIMETCallback section is significantly less detailed than other callback documentation in this file. Consider adding:
- A parameters table (or a note that it has no direct parameters)
- A reference to the exporter.aimet configuration section where users control quantization behavior
- A default behavior note (e.g., "Added automatically when exporter.aimet.active=True")
- A brief explanation of what quantization techniques are applied
Compare this to TestOnTrainEnd (lines 146-157) or ConvertOnTrainEnd (lines 113-133) for structure examples.
luxonis_train/callbacks/luxonis_progress_bar.py (1)
62-78: ⚡ Quick win
Update docstring to reflect the new tuple-based signature.
The docstring at line 73 doesn't clearly describe the new structure. Consider updating it to explain that table is an iterable of tuples where each tuple represents a row (e.g., (name, value1, value2, ...)), and column_names provides the header labels.
📝 Proposed docstring improvement
 @abstractmethod
 def print_table(
     self,
     title: str,
     table: Iterable[tuple[str | int | float, ...]],
     column_names: list[str],
 ) -> None:
     """Prints table to the console.

     @type title: str
     @param title: Title of the table
-    @type table: Mapping[str, int | str | float]
-    @param table: Table to print
+    @type table: Iterable[tuple[str | int | float, ...]]
+    @param table: Iterable of row tuples, where each tuple contains
+        the values for one row (e.g., (name, value1, value2, ...))
     @type column_names: list[str]
     @param column_names: Names of the columns in the table
     """
luxonis_train/__main__.py (1)
485-497: ⚡ Quick win
Add docstring for CLI help consistency.
The quantize command lacks a docstring, unlike all other commands (train, tune, test, export, etc.). This means users won't see help text when running luxonis_train quantize --help.
📝 Proposed fix
 @app.command(group=export_group, sort_key=1)
 def quantize(
     opts: list[str] | None = None,
     /,
     *,
     config: str | None = None,
     weights: str | None = None,
 ):
+    """Quantize the model using AIMET.
+
+    @type config: str
+    @param config: Path to the configuration file or a name of a
+        predefined model.
+    @type weights: str
+    @param weights: Path to the model weights.
+    @type opts: list[str]
+    @param opts: A list of optional CLI overrides of the config file.
+    """
     model = create_model(
         config, opts, weights=weights, allow_empty_dataset=True
     )
     model.quantize()
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@configs/README.md`:
- Line 513: Fix the malformed markdown link in the table row for `aimet`: locate
the cell containing the link text `](`#aimet`)\]` (adjacent to the `aimet` table
entry) and remove the extraneous `\]` so the link becomes `](`#aimet`)`,
preserving the rest of the row content.
In
`@luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py`:
- Around line 61-62: The per-image cloning of self.confusion_matrix inside the
loop is redundant and ineffective: move the clone logic out of the per-image
loop in the _update method so you clone at most once per _update call (and do it
outside any torch.inference_mode() context if you need a non-inference tensor).
Concretely, in the _update implementation check
self.confusion_matrix.is_inference() once before iterating over images and, if
needed, replace self.confusion_matrix = self.confusion_matrix.clone() there (or
perform the clone after leaving inference mode) so you avoid repeated
allocations and preserve intended mutation semantics.
In `@luxonis_train/core/core.py`:
- Around line 1461-1465: The call site that builds the adaround arguments
(passing adaround_reg_param, adaround_warm_start, adaround_iterations,
adaround_beta_range, etc.) and the place selecting epochs uses truthy "or"
fallbacks which incorrectly override valid numeric zeros; change each fallback
from the pattern "value or cfg.foo" to an explicit None check such as "value if
value is not None else cfg.foo" (apply this to adaround_reg_param,
adaround_warm_start, adaround_iterations, adaround_beta_range and the epochs
selection) so that 0.0 or 0 are preserved but None still falls back to the
config defaults.
- Around line 1430-1434: The baseline pre-quantization test is being run before
custom weights are applied; move the weights-loading so
model.load_checkpoint(weights) executes before computing pre_quant_test: ensure
you call model.load_checkpoint(weights) (when weights is not None) prior to
invoking model.reparametrize().eval() and self.pl_trainer.test(...) so
pre_quant_test reflects the provided weights.
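The truthy-fallback problem from the first core.py item above can be shown in isolation. The helper names below are hypothetical and only illustrate the two patterns; they do not appear in the PR.

```python
def pick_with_or(value, default):
    """Truthy fallback: treats 0, 0.0 and None alike."""
    return value or default


def pick_if_not_none(value, default):
    """Explicit fallback: only None is replaced by the default."""
    return value if value is not None else default


# An explicit zero (for example, zero epochs) is silently replaced
# by the config default when `or` is used.
print(pick_with_or(0, 20))         # 20
print(pick_if_not_none(0, 20))     # 0
print(pick_if_not_none(None, 20))  # 20
```

This matters for any numeric override where zero is a meaningful value, such as an epoch count of 0.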
In `@luxonis_train/core/utils/aimet_utils.py`:
- Around line 58-67: pass_calibration_data currently runs batches from
val_loader on the CPU, causing a device mismatch when the model is on CUDA;
update the function so it determines the model device (e.g., device =
next(model.parameters()).device), moves imgs to that device before invoking
model.forward (imgs = imgs.to(device)), and run the loop under torch.no_grad()
to avoid building gradients.
- Around line 159-169: The QAT training loop fails to move batch tensors to the
device before forwarding; in the loop that iterates over train_loader and calls
model.training_step((imgs, labels)) you must transfer imgs and labels to the
same device as the model (e.g., device = next(model.parameters()).device) before
computing loss/backward/optimizer.step; update the loop in the block using
track(...) so inputs and targets are moved to device (and ensure any
dtype/device-specific handling for labels is preserved) prior to calling
model.training_step, optimizer.zero_grad, loss.backward, and optimizer.step.
- Around line 86-107: Guard against val_loader.batch_size being None before
using it to compute num_batches: when building the AdaroundParameters
(AdaroundParameters(...) in this block) compute num_batches using a safe
expression that checks val_loader.batch_size and falls back to using
len(val_loader) (or a sensible default batch size) so math.ceil(2000 /
val_loader.batch_size) is only executed when batch_size is not None; keep
default_num_iterations set to adaround_iterations (no forced fallback) and pass
the safe ada_params into Adaround.apply_adaround as before.
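The batch-size guard from the last item above could be factored into a small helper. This is a sketch under assumptions: `adaround_num_batches` is a hypothetical name, the 2000-sample target is taken from the expression quoted in the comment, and falling back to the loader length is one of the two options the comment mentions. A PyTorch DataLoader's batch_size can be None, for example when a custom batch sampler is used.

```python
import math


def adaround_num_batches(batch_size, loader_len, target_samples=2000):
    """Number of batches needed to cover roughly `target_samples` samples.

    Falls back to the full loader length when the batch size is unknown.
    """
    if batch_size is None:
        return loader_len
    return math.ceil(target_samples / batch_size)


print(adaround_num_batches(32, loader_len=100))    # 63
print(adaround_num_batches(None, loader_len=17))   # 17
```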
---
Outside diff comments:
In `@luxonis_train/callbacks/luxonis_progress_bar.py`:
- Around line 323-365: The docstring for print_table is out of date: remove
references to old key_name/value_name parameters and update descriptions to
document the current signature parameters (title: str, table:
Iterable[tuple[str|int|float, ...]], column_names: list[str], console: Console |
None). In the print_table docstring mention that table is an iterable of rows
where the first element is the row key/name and the rest are column values, and
that column_names provides the header labels; keep the console description and
types in the docstring consistent with the function signature and typing used in
the method.
- Around line 173-199: The print_table method's docstring is outdated: it
mentions removed parameters key_name and value_name; update the docstring on
print_table to describe the current signature (title: str, table:
Iterable[tuple[str|int|float, ...]], column_names: list[str]), explain what
table and column_names represent and any formatting or return behavior, and
remove references to key_name/value_name; locate the method by name print_table
in luxonis_progress_bar.py and adjust the docstring text accordingly to match
parameters and types used in the function.
In `@luxonis_train/utils/dataset_metadata.py`:
- Around line 119-133: The docstring promises a ValueError when a requested task
is not present, but n_keypoints currently returns 0 for unknown task names;
update n_keypoints to raise ValueError when task_name is provided but not found
in self._n_keypoints (use the method name n_keypoints and the attribute
self._n_keypoints to locate the code) so behavior matches the docstring, or
alternatively update the docstring to state that unknown tasks return 0—prefer
changing the implementation to raise ValueError to preserve the documented
contract.
---
Nitpick comments:
In `@luxonis_train/__main__.py`:
- Around line 485-497: The quantize CLI command is missing a docstring so it
doesn't show help text; add a short descriptive docstring immediately under the
quantize function definition (the function named quantize that calls
create_model(...) and model.quantize()) describing what the command does and its
parameters (opts, config, weights) and any important behavior (e.g.,
allow_empty_dataset=True), so the CLI help (luxonis_train quantize --help)
displays meaningful information.
In `@luxonis_train/callbacks/aimet_callback.py`:
- Around line 13-16: Add informational logging around the AIMET quantization
call so users know when quantization starts, is in progress, and finishes (or
fails): in the callback method on_train_end, before calling
pl_module.core.quantize(self.get_checkpoint(pl_module)) emit a log/info message
that quantization is starting (include model/checkpoint identification via
get_checkpoint(pl_module)), optionally log progress checkpoints if available
from LuxonisLightningModule.core.quantize, and after the call emit a completion
log or catch exceptions to log an error; reference on_train_end, get_checkpoint,
and the LuxonisLightningModule.core.quantize call when making the changes.
In `@luxonis_train/callbacks/luxonis_progress_bar.py`:
- Around line 62-78: Update the print_table docstring in the abstractmethod
print_table to describe the new tuple-based signature: state that table is an
Iterable of tuples where each tuple is a row (e.g., ("row_label", value1,
value2, ...)) and that column_names is a list of header labels mapping to tuple
positions; include a short example and clarify accepted element types (str | int
| float) and that tuple lengths should match len(column_names). Ensure the
description references the print_table method and its parameters (title, table,
column_names).
In `@luxonis_train/callbacks/README.md`:
- Around line 189-192: Update the AIMETCallback documentation block to match the
structure of other callbacks: add a brief parameters table (or explicitly state
"no direct parameters"), add a line referencing the exporter.aimet configuration
section (e.g., "Controlled via exporter.aimet.* settings"), state the default
behavior (for example "Automatically added when exporter.aimet.active=True"),
and include a short summary of the quantization techniques applied (e.g.,
post-training quantization, per-channel bias/scale folding, and calibration
method used). Ensure you reference the AIMETCallback symbol and mirror the style
used in TestOnTrainEnd/ConvertOnTrainEnd sections for consistency.
In `@luxonis_train/lightning/utils.py`:
- Around line 56-58: The __init__ of LossAccumulator accepts unused *args and
**kwargs which are silently discarded; change the signature to def
__init__(self): (remove *args, **kwargs) and keep the body calling
super().__init__(float) and self.counts = defaultdict(int); this removes the
misleading API surface (or alternatively, if those params were intentionally for
pickling, explicitly document that intent and forward them to super via
super().__init__(float, *args, **kwargs)). Ensure you update
LossAccumulator.__init__, the call to super().__init__(float), and the
initialization of counts accordingly.
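On the LossAccumulator nitpick above, one possible reason for the permissive signature (an assumption, not something the author states here) is that copying a defaultdict subclass re-invokes the constructor with the default factory as a positional argument. The two classes below are hypothetical stand-ins that isolate this behavior.

```python
import copy
from collections import defaultdict


class StrictAccumulator(defaultdict):
    """Constructor takes no extra arguments."""

    def __init__(self):
        super().__init__(float)
        self.counts = defaultdict(int)


class TolerantAccumulator(defaultdict):
    """Constructor accepts and ignores extra arguments."""

    def __init__(self, *args, **kwargs):
        # deepcopy/pickle rebuild the object by calling cls(float),
        # so the positional factory must be accepted here.
        super().__init__(float)
        self.counts = defaultdict(int)


tolerant = TolerantAccumulator()
tolerant["loss"] += 1.5
print(copy.deepcopy(tolerant)["loss"])  # 1.5

try:
    copy.deepcopy(StrictAccumulator())
except TypeError as exc:
    print("deepcopy failed:", exc)
```

In this sketch the copy gets a fresh `counts` attribute from `__init__` rather than a copy of the original one, so removing the extra parameters is not the only thing to check if copies must preserve the counters.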
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: eacfc6d3-7d1d-4fc3-afdf-f843d8e3f039
📒 Files selected for processing (49)
- .github/workflows/ci.yaml
- README.md
- configs/README.md
- luxonis_train/__main__.py
- luxonis_train/attached_modules/losses/adaptive_detection_loss.py
- luxonis_train/attached_modules/losses/efficient_keypoint_bbox_loss.py
- luxonis_train/attached_modules/losses/reconstruction_segmentation_loss.py
- luxonis_train/attached_modules/metrics/base_metric.py
- luxonis_train/attached_modules/metrics/confusion_matrix/detection_confusion_matrix.py
- luxonis_train/attached_modules/visualizers/base_visualizer.py
- luxonis_train/attached_modules/visualizers/embeddings_visualizer.py
- luxonis_train/attached_modules/visualizers/segmentation_visualizer.py
- luxonis_train/callbacks/README.md
- luxonis_train/callbacks/__init__.py
- luxonis_train/callbacks/aimet_callback.py
- luxonis_train/callbacks/gradcam_visualizer.py
- luxonis_train/callbacks/luxonis_progress_bar.py
- luxonis_train/config/config.py
- luxonis_train/config/predefined_models/base_predefined_model.py
- luxonis_train/core/core.py
- luxonis_train/core/utils/aimet_utils.py
- luxonis_train/core/utils/annotate_utils.py
- luxonis_train/core/utils/infer_utils.py
- luxonis_train/lightning/luxonis_lightning.py
- luxonis_train/lightning/utils.py
- luxonis_train/loaders/base_loader.py
- luxonis_train/loaders/dummy_loader.py
- luxonis_train/loaders/luxonis_loader_torch.py
- luxonis_train/loaders/luxonis_perlin_loader_torch.py
- luxonis_train/nodes/backbones/dinov3/dinov3.py
- luxonis_train/nodes/backbones/efficientnet.py
- luxonis_train/nodes/backbones/pplcnet_v3/blocks.py
- luxonis_train/nodes/blocks/__init__.py
- luxonis_train/nodes/blocks/blocks.py
- luxonis_train/registry.py
- luxonis_train/utils/__init__.py
- luxonis_train/utils/dataset_metadata.py
- media/anomaly_detection_diagram.drawio
- pyproject.toml
- requirements-aimet.txt
- requirements.txt
- tests/conftest.py
- tests/integration/backbone_model_utils.py
- tests/integration/test_callbacks.py
- tests/integration/test_combinations.py
- tests/integration/test_custom_model.py
- tests/integration/test_predefined_models.py
- tests/unittests/test_loaders/test_base_loader.py
- tests/unittests/test_utils/test_dataset_metadata.py
💤 Files with no reviewable changes (3)
- luxonis_train/utils/__init__.py
- requirements.txt
- luxonis_train/attached_modules/visualizers/segmentation_visualizer.py
Purpose
Adds an option to quantize a trained model using various PTQ and QAT techniques.
Specification
- LuxonisModel.quantize method
- AIMETCallback
- exporter.aimet to the Config
- forward method to LuxonisLightningModule
- full_forward
- print_table as a required abstract method to BaseLuxonisProgressBar
- __getstate__ and __setstate__ in LuxonisLightningModule
- config.yaml
Dependencies & Potential Impact
None / not applicable
Deployment Plan
None / not applicable
Testing & Validation
- AIMETCallback in test_callbacks
- LuxonisModel.quantize for all predefined models with the full set of PTQ techniques enabled
Summary by CodeRabbit
New Features
Bug Fixes & Improvements
Documentation