Skip to content

Commit 03a2a7a

Browse files
WhiteSwan1claude
andcommitted
[refactor] SID: add ResidualQuantizer / BaseSidModel base classes
First of three PRs splitting the Semantic-ID models onto a shared base. Purely additive — only the backend-agnostic foundation, no concrete quantizer or model and no edits to existing files. RQ-KMeans follows in PR2, RQ-VAE in PR3. What this adds: - ResidualQuantizer (abstract): owns the shared state (embed_dim, per-layer codebook sizes via normalize_n_embed, residual-normalization flag, layer list; asserts n_layers >= 1) and the shared residual walk — _residual_pass drives the concrete get_codes / decode_codes / output_dim. Subclasses implement just _quantize_layer (encode) and _lookup_code (decode), plus forward and get_codebook_embeddings. - BaseSidModel (abstract): the shared SID model scaffold — embedding-feature extraction, loss/metric init (reconstruction MSE via torchmetrics.MeanSquaredError + codebook coverage via UniqueRatio), and shared config parsing — that SidRqkmeans / SidRqvae subclass. - UniqueRatio (tzrec/metrics/unique_ratio.py): codebook-coverage metric (mean per-batch unique-row ratio) with empty-batch guard + DDP reduction. Tests: normalize_n_embed; the abstract-base contract; the concrete residual walk via a fake one-primitive subclass; and the UniqueRatio metric. No proto changes and no edits to existing modules; __init__.py is a bare package marker (no re-exports). The QuantizeForwardMode / QuantizeOutput / ResidualQuantizerOutput types and the concrete SidRqkmeans / SidRqvae models ship with the code that uses them in PR2 / PR3. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 095f6d2 commit 03a2a7a

6 files changed

Lines changed: 580 additions & 0 deletions

File tree

tzrec/metrics/unique_ratio.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2026, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import torch
13+
from torchmetrics import Metric
14+
15+
16+
class UniqueRatio(Metric):
17+
"""Mean per-batch unique-SID ratio (distinct rows / batch size).
18+
19+
Averages, over batches, the fraction of distinct semantic-ID rows in each
20+
batch. It is a cheap (two-scalar state) **diversity proxy**, NOT global
21+
codebook coverage: a SID repeated across different batches counts as
22+
distinct in each, and smaller batches bias the value toward 1.0. Empty
23+
batches are skipped; the per-rank sums reduce by ``sum`` (a count-weighted
24+
mean).
25+
"""
26+
27+
higher_is_better = True
28+
is_differentiable = False
29+
30+
def __init__(self, **kwargs) -> None:
31+
super().__init__(**kwargs)
32+
self.add_state("ratio_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
33+
self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum")
34+
35+
def update(self, codes: torch.Tensor) -> None:
36+
"""Accumulate one batch's distinct-row ratio.
37+
38+
Args:
39+
codes (Tensor): semantic-ID codes, shape (B, n_layers).
40+
"""
41+
batch_size = codes.shape[0]
42+
if batch_size == 0:
43+
return
44+
unique = torch.unique(codes, dim=0).shape[0]
45+
self.ratio_sum += unique / batch_size
46+
self.count += 1
47+
48+
def compute(self) -> torch.Tensor:
49+
"""Mean per-batch unique ratio (NaN before any non-empty update)."""
50+
return self.ratio_sum / self.count

tzrec/metrics/unique_ratio_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) 2026, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
16+
from tzrec.metrics.unique_ratio import UniqueRatio
17+
18+
19+
class UniqueRatioTest(unittest.TestCase):
20+
def test_single_batch_ratio(self) -> None:
21+
metric = UniqueRatio()
22+
# 3 distinct rows out of 4 -> 0.75.
23+
metric.update(torch.tensor([[1, 2], [1, 2], [3, 4], [5, 6]]))
24+
self.assertAlmostEqual(metric.compute().item(), 0.75, places=6)
25+
26+
def test_mean_over_batches(self) -> None:
27+
metric = UniqueRatio()
28+
metric.update(torch.tensor([[1, 1], [1, 1]])) # 1/2 = 0.5
29+
metric.update(torch.tensor([[1, 1], [2, 2]])) # 2/2 = 1.0
30+
# Per-batch mean = 0.75 (a global distinct/total would give 0.5).
31+
self.assertAlmostEqual(metric.compute().item(), 0.75, places=6)
32+
33+
def test_empty_batch_skipped(self) -> None:
34+
metric = UniqueRatio()
35+
metric.update(torch.empty(0, 3, dtype=torch.long))
36+
self.assertEqual(metric.count.item(), 0.0)
37+
self.assertTrue(torch.isnan(metric.compute()))
38+
39+
40+
if __name__ == "__main__":
41+
unittest.main()

tzrec/models/sid_model.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) 2026, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
"""BaseSidModel: shared base for semantic-ID generation models."""
13+
14+
from typing import Any, Dict, List, Optional
15+
16+
import torch
17+
import torchmetrics
18+
19+
from tzrec.datasets.utils import BASE_DATA_GROUP, Batch
20+
from tzrec.features.feature import BaseFeature
21+
from tzrec.metrics.unique_ratio import UniqueRatio
22+
from tzrec.models.model import BaseModel
23+
from tzrec.protos.model_pb2 import ModelConfig
24+
25+
26+
class BaseSidModel(BaseModel):
27+
"""Shared base for semantic-ID (SID) generation models.
28+
29+
Factors the structure common to :class:`SidRqvae` (RQ-VAE) and
30+
:class:`SidRqkmeans` (residual K-Means):
31+
32+
- the shared config fields every SID proto carries —
33+
``embedding_feature_name`` (``_embedding_feature_name``), ``input_dim``
34+
(``_input_dim``), ``normalize_residuals`` (``_normalize_residuals``),
35+
and the per-layer ``codebook`` (``_n_embed_list`` / ``_n_layers``),
36+
- reading the item-embedding feature out of ``Batch.dense_features``,
37+
- the eval metrics every SID model reports — reconstruction ``mse`` and
38+
``unique_sid_ratio`` (mean per-batch unique-SID ratio, a diversity
39+
proxy).
40+
41+
Subclasses build their quantizer in ``__init__`` (after calling
42+
``super().__init__``) and implement :meth:`predict` and :meth:`loss`.
43+
They extend :meth:`init_metric` (via ``super()``) and implement
44+
:meth:`update_metric` to populate the registered metrics
45+
(:meth:`update_train_metric` defaults to a no-op).
46+
47+
Args:
48+
model_config (ModelConfig): an instance of ModelConfig.
49+
features (list): list of features.
50+
labels (list): list of label names.
51+
sample_weights (list): sample weight names.
52+
"""
53+
54+
def __init__(
55+
self,
56+
model_config: ModelConfig,
57+
features: List[BaseFeature],
58+
labels: List[str],
59+
sample_weights: Optional[List[str]] = None,
60+
**kwargs: Any,
61+
) -> None:
62+
super().__init__(model_config, features, labels, sample_weights, **kwargs)
63+
64+
cfg = self._model_config
65+
# Config fields shared by every SID model (present on each SID proto
66+
# message): the item-embedding feature, the input dimension, the
67+
# residual-normalization toggle, and the per-layer codebook.
68+
self._embedding_feature_name = cfg.embedding_feature_name
69+
self._input_dim = cfg.input_dim
70+
self._normalize_residuals = cfg.normalize_residuals
71+
72+
assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]"
73+
self._n_embed_list = list(cfg.codebook)
74+
self._n_layers = len(self._n_embed_list)
75+
76+
def _extract_feature(
77+
self, batch: Batch, feature_name: Optional[str] = None
78+
) -> torch.Tensor:
79+
"""Extract a named dense feature from ``Batch.dense_features``.
80+
81+
Args:
82+
batch (Batch): input batch data.
83+
feature_name (str, optional): feature name to extract.
84+
Defaults to ``self._embedding_feature_name``.
85+
"""
86+
if feature_name is None:
87+
feature_name = self._embedding_feature_name
88+
kt = batch.dense_features[BASE_DATA_GROUP]
89+
return kt[feature_name]
90+
91+
def init_loss(self) -> None:
92+
"""Initialize loss modules.
93+
94+
SID models compute their losses internally and pass them through
95+
``predictions``; there is no external loss module to register.
96+
"""
97+
pass
98+
99+
def init_metric(self) -> None:
100+
"""Initialize the eval metrics shared by all SID models.
101+
102+
``mse``: reconstruction error (input vs. quantized / decoded).
103+
``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows /
104+
batch size; a batch-size-sensitive diversity proxy, not global
105+
coverage). Subclasses call ``super().init_metric()`` then add extras.
106+
"""
107+
self._metric_modules["mse"] = torchmetrics.MeanSquaredError()
108+
self._metric_modules["unique_sid_ratio"] = UniqueRatio()
109+
110+
def update_train_metric(
111+
self,
112+
predictions: Dict[str, torch.Tensor],
113+
batch: Batch,
114+
) -> None:
115+
"""Update train-path metric state.
116+
117+
Default is a no-op: K-Means has no train-time codes, so only models
118+
with a meaningful train signal (RQ-VAE) override this.
119+
"""
120+
return
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) 2026, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.

0 commit comments

Comments
 (0)