Skip to content

Commit 2c8ffa4

Browse files
WhiteSwan1claude
andcommitted
[refactor] SID: add ResidualQuantizer / BaseSidModel base classes
First of three PRs splitting the Semantic-ID models onto a shared base. This one is purely additive — only the backend-agnostic foundation, with no concrete quantizer or model and no edits to existing files. RQ-KMeans follows in PR2, RQ-VAE in PR3. What this adds: - ResidualQuantizer (abstract): owns the state every residual quantizer shares — embed_dim, per-layer codebook sizes (normalize_n_embed), residual-normalization flag, and the layer list — and the shared residual walk: _residual_pass drives the concrete get_codes / decode_codes / output_dim. Subclasses implement just the per-layer primitives _quantize_layer (encode) and _lookup_code (decode), plus forward and get_codebook_embeddings. - BaseSidModel (abstract): the shared SID model scaffold — embedding-feature extraction, loss/metric initialization, and shared config parsing — that SidRqkmeans and SidRqvae will subclass. - types.py: QuantizeForwardMode, QuantizeOutput, ResidualQuantizerOutput. Tests: normalize_n_embed and the abstract-base contract (shared state present, backend primitives raise). No proto changes and no edits to existing modules. The concrete SidRqkmeans / SidRqvae models (and their K-Means / VQ layers, kmeans helpers, protos, and the BaseModel.on_train_end hook) arrive in PR2 / PR3. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 095f6d2 commit 2c8ffa4

5 files changed

Lines changed: 560 additions & 0 deletions

File tree

tzrec/models/sid_model.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) 2026, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
"""BaseSidModel: shared base for semantic-ID generation models."""
13+
14+
from typing import Any, List, Optional
15+
16+
import torch
17+
import torchmetrics
18+
19+
from tzrec.datasets.utils import BASE_DATA_GROUP, Batch
20+
from tzrec.features.feature import BaseFeature
21+
from tzrec.models.model import BaseModel
22+
from tzrec.protos.model_pb2 import ModelConfig
23+
24+
25+
class BaseSidModel(BaseModel):
26+
"""Shared base for semantic-ID (SID) generation models.
27+
28+
Factors the structure common to :class:`SidRqvae` (RQ-VAE) and
29+
:class:`SidRqkmeans` (residual K-Means):
30+
31+
- the shared config fields every SID proto carries —
32+
``embedding_feature_name`` (``_embedding_feature_name``), ``input_dim``
33+
(``_input_dim``), ``normalize_residuals`` (``_normalize_residuals``),
34+
and the per-layer ``codebook`` (``_n_embed_list`` / ``_n_layers``),
35+
- reading the item-embedding feature out of ``Batch.dense_features``,
36+
- the eval metrics every SID model reports — reconstruction ``mse`` and
37+
``unique_sid_ratio`` (codebook coverage).
38+
39+
Subclasses build their quantizer in ``__init__`` (after calling
40+
``super().__init__``) and implement :meth:`predict` and :meth:`loss`.
41+
They extend :meth:`init_metric` / :meth:`update_metric` with any
42+
backend-specific metrics.
43+
44+
Args:
45+
model_config (ModelConfig): an instance of ModelConfig.
46+
features (list): list of features.
47+
labels (list): list of label names.
48+
sample_weights (list): sample weight names.
49+
"""
50+
51+
def __init__(
52+
self,
53+
model_config: ModelConfig,
54+
features: List[BaseFeature],
55+
labels: List[str],
56+
sample_weights: Optional[List[str]] = None,
57+
**kwargs: Any,
58+
) -> None:
59+
super().__init__(model_config, features, labels, sample_weights, **kwargs)
60+
61+
cfg = self._model_config
62+
# Config fields shared by every SID model (present on each SID proto
63+
# message): the item-embedding feature, the input dimension, the
64+
# residual-normalization toggle, and the per-layer codebook.
65+
self._embedding_feature_name = cfg.embedding_feature_name
66+
self._input_dim = cfg.input_dim
67+
self._normalize_residuals = cfg.normalize_residuals
68+
69+
assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]"
70+
self._n_embed_list = list(cfg.codebook)
71+
self._n_layers = len(self._n_embed_list)
72+
73+
def _extract_feature(
74+
self, batch: Batch, feature_name: Optional[str] = None
75+
) -> torch.Tensor:
76+
"""Extract a named dense feature from ``Batch.dense_features``.
77+
78+
Args:
79+
batch (Batch): input batch data.
80+
feature_name (str, optional): feature name to extract.
81+
Defaults to ``self._embedding_feature_name``.
82+
"""
83+
if feature_name is None:
84+
feature_name = self._embedding_feature_name
85+
kt = batch.dense_features[BASE_DATA_GROUP]
86+
return kt[feature_name]
87+
88+
def init_loss(self) -> None:
89+
"""Initialize loss modules.
90+
91+
SID models compute their losses internally and pass them through
92+
``predictions``; there is no external loss module to register.
93+
"""
94+
pass
95+
96+
def init_metric(self) -> None:
97+
"""Initialize the eval metrics shared by all SID models.
98+
99+
``mse``: reconstruction error (input vs. quantized / decoded).
100+
``unique_sid_ratio``: codebook coverage = unique SIDs / batch size.
101+
Subclasses call ``super().init_metric()`` then add their extras.
102+
"""
103+
self._metric_modules["mse"] = torchmetrics.MeanMetric()
104+
self._metric_modules["unique_sid_ratio"] = torchmetrics.MeanMetric()
105+
106+
def update_train_metric(
107+
self,
108+
predictions: dict,
109+
batch: Batch,
110+
) -> None:
111+
"""Update train-path metric state.
112+
113+
Default is a no-op: K-Means has no train-time codes, so only models
114+
with a meaningful train signal (RQ-VAE) override this.
115+
"""
116+
return
117+
118+
def _update_unique_sid_ratio(self, codes: torch.Tensor) -> None:
119+
"""Update the codebook-coverage metric (unique SIDs / batch size).
120+
121+
Args:
122+
codes (Tensor): semantic-ID codes, shape (B, n_layers).
123+
"""
124+
B = codes.shape[0]
125+
if B == 0: # empty final shard under DDP/TorchRec
126+
return
127+
unique_sids = torch.unique(codes, dim=0).shape[0]
128+
self._metric_modules["unique_sid_ratio"].update(unique_sids / B)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) 2026, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from tzrec.modules.sid_generation.residual_quantizer import (
13+
ResidualQuantizer,
14+
)
15+
from tzrec.modules.sid_generation.types import (
16+
QuantizeForwardMode,
17+
QuantizeOutput,
18+
ResidualQuantizerOutput,
19+
)
20+
21+
__all__ = [
22+
"QuantizeForwardMode",
23+
"QuantizeOutput",
24+
"ResidualQuantizerOutput",
25+
"ResidualQuantizer",
26+
]
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Copyright (c) 2026, Alibaba Group;
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
"""ResidualQuantizer: abstract base for multi-layer residual quantizers."""
13+
14+
from typing import List, Tuple, Union
15+
16+
import torch
17+
from torch import nn
18+
from torch.nn import functional as F
19+
20+
21+
def normalize_n_embed(n_embed: Union[int, List[int]], n_layers: int) -> List[int]:
22+
"""Broadcast a scalar codebook size to a per-layer list (or validate one).
23+
24+
Args:
25+
n_embed (int|List[int]): codebook size, shared or per-layer.
26+
n_layers (int): number of residual quantization layers.
27+
28+
Returns:
29+
List[int]: per-layer codebook sizes, length ``n_layers``.
30+
"""
31+
if isinstance(n_embed, int):
32+
return [n_embed] * n_layers
33+
assert len(n_embed) == n_layers, (
34+
"length of n_embed and n_layers must be same, "
35+
f"but got {len(n_embed)} vs {n_layers}"
36+
)
37+
return list(n_embed)
38+
39+
40+
class ResidualQuantizer(nn.Module):
41+
"""Abstract base for multi-layer residual quantization.
42+
43+
Shared contract for the two SID quantizer backends — the VQ-based,
44+
gradient-trained :class:`ResidualVectorQuantizer` and the K-Means-based,
45+
offline-FAISS-trained :class:`ResidualKMeansQuantizer`. Both quantize the
46+
residual of the previous layer:
47+
48+
residual_0 = input
49+
for each layer i:
50+
(optionally) residual_i = L2_normalize(residual_i)
51+
code_i, quantized_i = layer_i(residual_i)
52+
residual_{i+1} = residual_i - quantized_i
53+
output = sum of all quantized_i
54+
55+
Semantic ID = (code_0, code_1, ..., code_{n_layers-1}).
56+
57+
This base owns the structural invariants (``embed_dim``, ``n_layers``,
58+
per-layer codebook sizes, residual normalization toggle) and the shared
59+
residual walk (:meth:`_residual_pass`, :meth:`get_codes`,
60+
:meth:`decode_codes`, :meth:`output_dim`). Subclasses build ``self.layers``
61+
and implement the per-layer primitives :meth:`_quantize_layer` (encode) and
62+
:meth:`_lookup_code` (decode), plus :meth:`forward` and
63+
:meth:`get_codebook_embeddings`.
64+
65+
Args:
66+
embed_dim (int): feature / codebook dimension.
67+
n_layers (int): number of residual quantization layers.
68+
n_embed (int|List[int]): codebook size per layer. Default: 256.
69+
normalize_residuals (bool): L2-normalize residuals before each
70+
layer. Default: False.
71+
"""
72+
73+
def __init__(
74+
self,
75+
embed_dim: int,
76+
n_layers: int,
77+
n_embed: Union[int, List[int]] = 256,
78+
normalize_residuals: bool = False,
79+
) -> None:
80+
super().__init__()
81+
self.embed_dim = embed_dim
82+
self.n_layers = n_layers
83+
self.normalize_residuals = normalize_residuals
84+
self.n_embed_list = normalize_n_embed(n_embed, n_layers)
85+
# Subclasses MUST populate this with one quantization layer each.
86+
self.layers: nn.ModuleList = nn.ModuleList()
87+
88+
def output_dim(self) -> int:
89+
"""Output dimension of the module."""
90+
return self.embed_dim
91+
92+
def forward(self, input: torch.Tensor): # noqa: ANN201
93+
"""Assign codes per layer and accumulate the quantized output."""
94+
raise NotImplementedError
95+
96+
def _quantize_layer(
97+
self,
98+
layer_idx: int,
99+
residual: torch.Tensor,
100+
temperature: float = 1.0,
101+
) -> Tuple[torch.Tensor, torch.Tensor]:
102+
"""Assign one layer's codes and look up its quantized vector.
103+
104+
Backend primitive behind the residual walk (encode-direction mirror of
105+
:meth:`_lookup_code`). ``temperature`` is used only by the VQ backend.
106+
107+
Args:
108+
layer_idx (int): quantization layer index.
109+
residual (Tensor): current residual, shape (B, D).
110+
temperature (float): Gumbel-Softmax temperature (VQ only).
111+
112+
Returns:
113+
codes (Tensor): per-layer cluster ids, shape (B,).
114+
quantized (Tensor): the layer's quantized vector, shape (B, D).
115+
"""
116+
raise NotImplementedError
117+
118+
def _residual_pass(
119+
self,
120+
input: torch.Tensor,
121+
temperature: float = 1.0,
122+
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
123+
"""Shared residual walk: per-layer assign, subtract, accumulate.
124+
125+
The quantized vector is subtracted detached (keeps the residual chain
126+
gradient-free) and accumulated (keeps gradient when the backend
127+
supplies it, e.g. VQ).
128+
129+
Args:
130+
input (Tensor): input embeddings, shape (B, D).
131+
temperature (float): forwarded to :meth:`_quantize_layer`.
132+
133+
Returns:
134+
cluster_ids (Tensor): stacked codes, shape (B, n_layers).
135+
aggregated (Tensor): sum of quantized vectors, shape (B, D).
136+
cumulative (List[Tensor]): running sum after each layer
137+
(``cumulative[-1] is aggregated``).
138+
"""
139+
residual = input
140+
all_codes: List[torch.Tensor] = []
141+
cumulative: List[torch.Tensor] = []
142+
aggregated = torch.zeros_like(input)
143+
for i in range(self.n_layers):
144+
if self.normalize_residuals:
145+
residual = F.normalize(residual, dim=-1)
146+
codes, quantized = self._quantize_layer(i, residual, temperature)
147+
all_codes.append(codes)
148+
aggregated = aggregated + quantized
149+
cumulative.append(aggregated)
150+
residual = residual - quantized.detach()
151+
cluster_ids = torch.stack(all_codes, dim=-1) # (B, n_layers)
152+
return cluster_ids, aggregated, cumulative
153+
154+
@torch.no_grad()
155+
def get_codes(self, input: torch.Tensor) -> torch.Tensor:
156+
"""Assign semantic IDs without updating the codebook.
157+
158+
Shared encode-direction mirror of :meth:`decode_codes`.
159+
160+
Args:
161+
input (Tensor): input embeddings, shape (B, D).
162+
163+
Returns:
164+
Tensor: cluster ids, shape (B, n_layers).
165+
"""
166+
cluster_ids, _, _ = self._residual_pass(input)
167+
return cluster_ids
168+
169+
@torch.no_grad()
170+
def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor:
171+
"""Get the codebook (centroid) weights for a specific layer.
172+
173+
Args:
174+
layer_idx (int): index of the quantization layer.
175+
176+
Returns:
177+
Tensor: codebook weights, shape (n_embed, embed_dim).
178+
"""
179+
raise NotImplementedError
180+
181+
def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor:
182+
"""Look up the codebook vectors for ``code_idx`` at ``layer_idx``.
183+
184+
The single backend-specific primitive :meth:`decode_codes` builds on
185+
(VQ reads ``embedding(idx)``, K-Means reads ``centroids[idx]``).
186+
187+
Args:
188+
layer_idx (int): index of the quantization layer.
189+
code_idx (Tensor): codebook indices, shape (B,).
190+
191+
Returns:
192+
Tensor: looked-up codebook vectors, shape (B, embed_dim).
193+
"""
194+
raise NotImplementedError
195+
196+
@torch.no_grad()
197+
def decode_codes(self, codes: torch.Tensor) -> torch.Tensor:
198+
"""Reconstruct embeddings from semantic ID codes (centroid sum).
199+
200+
Args:
201+
codes (Tensor): cluster ids, shape (B, n_layers).
202+
203+
Returns:
204+
Tensor: reconstructed embeddings, shape (B, embed_dim).
205+
"""
206+
# Seed from the first lookup so device and dtype follow the codebook
207+
# (avoids pinning the sum to fp32 under mixed precision). n_layers >= 1
208+
# is guaranteed by the codebook config.
209+
quantized_sum = self._lookup_code(0, codes[:, 0])
210+
for i in range(1, self.n_layers):
211+
quantized_sum = quantized_sum + self._lookup_code(i, codes[:, i])
212+
return quantized_sum

0 commit comments

Comments
 (0)