From 89fad00781a0ddf66a8953c9995359f145d88db8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 02:15:17 +0000 Subject: [PATCH 001/129] [feat] SID: port RQ-VAE and RQ-KMeans semantic-ID generation models Bring the SID-generation stack (from the remove_ema_2 working branch) onto a clean upstream base as the starting point for the base-class abstraction refactor. Net-new files only: - models: sid_rqvae.py, sid_rqkmeans.py (+ tests), _sid_helpers.py - modules/sid_generation: rqvae, residual_quantized, residual_kmeans, kmeans, vector_quantize, clip_loss, types - protos: sid_model.proto + SidRqvae/SidRqkmeans wired into model.proto 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/_sid_helpers.py | 24 ++ tzrec/models/sid_rqkmeans.py | 341 +++++++++++++++ tzrec/models/sid_rqkmeans_test.py | 184 ++++++++ tzrec/models/sid_rqvae.py | 259 +++++++++++ tzrec/models/sid_rqvae_test.py | 405 ++++++++++++++++++ tzrec/modules/sid_generation/__init__.py | 48 +++ tzrec/modules/sid_generation/clip_loss.py | 240 +++++++++++ tzrec/modules/sid_generation/kmeans.py | 290 +++++++++++++ .../modules/sid_generation/residual_kmeans.py | 373 ++++++++++++++++ .../sid_generation/residual_quantized.py | 399 +++++++++++++++++ tzrec/modules/sid_generation/rqvae.py | 372 ++++++++++++++++ tzrec/modules/sid_generation/types.py | 55 +++ .../modules/sid_generation/vector_quantize.py | 264 ++++++++++++ tzrec/protos/model.proto | 5 + tzrec/protos/models/sid_model.proto | 94 ++++ 15 files changed, 3353 insertions(+) create mode 100644 tzrec/models/_sid_helpers.py create mode 100644 tzrec/models/sid_rqkmeans.py create mode 100644 tzrec/models/sid_rqkmeans_test.py create mode 100644 tzrec/models/sid_rqvae.py create mode 100644 tzrec/models/sid_rqvae_test.py create mode 100644 tzrec/modules/sid_generation/__init__.py create mode 100644 tzrec/modules/sid_generation/clip_loss.py create mode 100644 tzrec/modules/sid_generation/kmeans.py create mode 100644 tzrec/modules/sid_generation/residual_kmeans.py create mode 100644 tzrec/modules/sid_generation/residual_quantized.py create mode 100644 tzrec/modules/sid_generation/rqvae.py create mode 100644 tzrec/modules/sid_generation/types.py create mode 100644 tzrec/modules/sid_generation/vector_quantize.py create mode 100644 tzrec/protos/models/sid_model.proto diff --git a/tzrec/models/_sid_helpers.py b/tzrec/models/_sid_helpers.py new file mode 100644 index 000000000..04946003c --- /dev/null +++ b/tzrec/models/_sid_helpers.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared helpers for SID-generation model wrappers.""" + +from typing import List + + +def parse_int_list(s: str) -> List[int]: + """Parse comma-separated int string, e.g. '256,128' -> [256, 128].""" + return [int(x.strip()) for x in s.split(",") if x.strip()] + + +def parse_float_list(s: str) -> List[float]: + """Parse comma-separated float string, e.g. '1.0,0.5' -> [1.0, 0.5].""" + return [float(x.strip()) for x in s.split(",") if x.strip()] diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py new file mode 100644 index 000000000..398e266cc --- /dev/null +++ b/tzrec/models/sid_rqkmeans.py @@ -0,0 +1,341 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SidRqkmeans: SID generation model using residual K-Means. + +Training is FAISS-only: ``predict`` collects embeddings into a CPU +buffer; the actual FAISS fit is triggered ONCE after the train_eval +loop ends, via the :meth:`BaseModel.on_train_end` lifecycle hook +(``tzrec.main`` calls ``_model.on_train_end()`` unconditionally). +""" + +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.distributed as dist +import torchmetrics +from google.protobuf.json_format import MessageToDict +from torch import nn + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.features.feature import BaseFeature +from tzrec.models._sid_helpers import parse_int_list +from tzrec.models.model import BaseModel +from tzrec.modules.sid_generation import RQKMeans +from tzrec.modules.sid_generation.kmeans import recon_diagnostics +from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils.logging_util import logger + + +def _coerce_proto_numbers(d: Dict) -> Dict: + """Coerce float-typed integers back to int. + + ``google.protobuf.Struct.number_value`` is always float, but most + ``faiss.Kmeans`` kwargs (``niter``, ``seed``, ``nredo``, ...) require + Python ``int``. This helper converts any float that is an exact + integer to ``int`` for downstream consumption. + """ + out: Dict = {} + for k, v in d.items(): + if isinstance(v, float) and v.is_integer(): + out[k] = int(v) + else: + out[k] = v + return out + + +class SidRqkmeans(BaseModel): + """SID generation model using residual K-Means (FAISS-only). + + No gradient-based training. The codebook is built once at the end + of the train_eval loop via a single FAISS K-Means pass over the + embeddings collected during training. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. + """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + cfg = self._model_config # SidRqkmeans proto message + self._embedding_feature_name = cfg.embedding_feature_name + + assert cfg.codebook, "codebook must be set, e.g. '256,256,256'" + n_embed_list = parse_int_list(cfg.codebook) + n_layers = len(n_embed_list) + + self._faiss_kwargs = ( + _coerce_proto_numbers(MessageToDict(cfg.faiss_kmeans_kwargs)) + if cfg.HasField("faiss_kmeans_kwargs") + else {} + ) + + self._rqkmeans = RQKMeans( + embed_dim=cfg.input_dim, + n_layers=n_layers, + n_embed=n_embed_list, + normalize_residuals=cfg.normalize_residuals, + faiss_kmeans_kwargs=self._faiss_kwargs, + ) + + # CPU buffer for embeddings collected during training; FAISS + # consumes it in on_train_end() at end-of-loop. + self._offline_buffer: List[torch.Tensor] = [] + + # KMeans has no learnable parameters (centroids use register_buffer). + # Add dummy param to keep optimizer/DDP happy. + self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) + + def _extract_embedding(self, batch: Batch) -> torch.Tensor: + """Extract item embedding from Batch.dense_features.""" + kt = batch.dense_features[BASE_DATA_GROUP] + return kt[self._embedding_feature_name] + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. + + Training: buffer embeddings only (codes are dummy until FAISS fits). + Eval/inference (after ``on_train_end``): real predict + lookup. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. + """ + embedding = self._extract_embedding(batch) + + # Training: buffer for the end-of-loop FAISS fit and return dummy + # codes — the codebook does not exist yet. + # TODO(perf): .cpu() is a synchronous D2H per step and the buffer + # grows unbounded with steps. Rework to either (a) GPU-resident + # buffer + bulk D2H in on_train_end with size cap, or (b) replace + # the train pass with an inference_mode corpus walk launched from + # on_train_end. Skipped here to avoid OOM-vs-refactor tradeoffs; + # tracked separately. + if self.is_train: + self._offline_buffer.append(embedding.detach().cpu()) + B = embedding.shape[0] + n_layers = self._rqkmeans.quantizer.n_layers + return { + "codes": torch.zeros( + B, n_layers, dtype=torch.long, device=embedding.device + ) + } + + result = self._rqkmeans(embedding) + + predictions: Dict[str, torch.Tensor] = { + "codes": result["codes"], + } + + if self.is_eval: + predictions["quantized"] = result["quantized"] + predictions["input_embedding"] = embedding + + return predictions + + def init_loss(self) -> None: + """Initialize loss modules. + + KMeans has no gradient loss; the codebook is built in + ``on_train_end`` at end of training. + """ + pass + + def loss( + self, predictions: Dict[str, torch.Tensor], batch: Batch + ) -> Dict[str, torch.Tensor]: + """Compute loss of the model. + + Returns zero loss to keep TrainWrapper backward happy. + _dummy_param * 0.0 ensures a compute graph exists so DDP + does not complain about unused parameters. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + + Return: + losses (dict): a dict of loss tensor. + """ + return {"dummy_loss": self._dummy_param.sum() * 0.0} + + def init_metric(self) -> None: + """Initialize metric modules. + + Only eval metrics are registered. During training ``predict`` + returns dummy zero codes (the codebook does not exist yet), so + any train-time metric would be either NaN or trivially constant. + ``compute_train_metric`` therefore returns an empty dict, which + the framework already tolerates. + """ + self._metric_modules["mse"] = torchmetrics.MeanMetric() + self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() + self._metric_modules["unique_sid_ratio"] = torchmetrics.MeanMetric() + + def update_train_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + ) -> None: + """No-op — see :meth:`init_metric`.""" + return + + def update_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + losses: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: + """Update metric state. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + losses (dict, optional): a dict of loss. + """ + codes = predictions["codes"] + B = codes.shape[0] + + if "input_embedding" in predictions: + mse, rel = recon_diagnostics( + predictions["input_embedding"], + predictions["quantized"], + ) + self._metric_modules["mse"].update(mse) + self._metric_modules["rel_loss"].update(rel) + + unique_sids = torch.unique(codes, dim=0).shape[0] + self._metric_modules["unique_sid_ratio"].update(unique_sids / B) + + @torch.no_grad() + def on_train_end(self) -> None: + """Trigger one-shot FAISS fit after the train_eval loop ends. + + Overrides :meth:`BaseModel.on_train_end`. Called unconditionally + by ``tzrec.main.train_and_evaluate`` after the training loop + exits. No-op when the buffer is empty. + + DDP behavior: + - rank0: receive local buffers via gather_object, concat, + run FAISS fit, then broadcast centroids to other ranks. + - other ranks: ship local buffer via gather_object(dst=0) + and wait for the broadcast. + """ + is_ddp = ( + dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 + ) + + # A local-only empty check would deadlock: the empty rank returns + # while peers block in gather_object below. OR the flag across + # ranks and bail together if any rank is empty. + local_empty = len(self._offline_buffer) == 0 + if is_ddp: + # int32, not bool — NCCL bool support is version-dependent. + flag = torch.tensor( + int(local_empty), + dtype=torch.int32, + device=self._dummy_param.device, + ) + dist.all_reduce(flag, op=dist.ReduceOp.MAX) + any_empty = bool(flag.item()) + else: + any_empty = local_empty + + if any_empty: + if (not is_ddp) or dist.get_rank() == 0: + logger.warning( + "[SidRqkmeans.on_train_end] at least one rank has an " + "empty offline buffer; skipping FAISS fit on all ranks. " + "Did the train_eval loop run, and is the per-rank shard " + "non-empty?" + ) + return + + if is_ddp: + # DDP path: every rank ships its local buffer to rank 0 via + # gather_object (variable-length pickle — fine for this one- + # shot, CPU-resident gather). Only rank 0 holds the corpus, + # so peak memory is O(world_size) on rank 0 and O(1) elsewhere + # (vs O(world_size²) for all_gather_object). + local = torch.cat(self._offline_buffer, dim=0) + del self._offline_buffer + self._offline_buffer = [] + + rank = dist.get_rank() + gathered: Optional[List[Optional[torch.Tensor]]] = ( + [None] * dist.get_world_size() if rank == 0 else None + ) + dist.gather_object(local, gathered, dst=0) + del local + if rank == 0: + assert gathered is not None + full = torch.cat([g for g in gathered if g is not None], dim=0) + del gathered + logger.info( + "[SidRqkmeans.on_train_end] rank0 fitting FAISS " + "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) + ) + self._rqkmeans.train_offline(full, verbose=True) + del full + # Broadcast centroids and set the init flag locally on every + # rank. ``_is_initialized`` is a bool buffer and NCCL's bool + # dtype support is inconsistent across versions, so we avoid + # a separate broadcast for it — all ranks enter this block in + # lockstep, so a local fill_() keeps state consistent. + for layer in self._rqkmeans.quantizer.layers: + dist.broadcast(layer.centroids, src=0) + layer._is_initialized.fill_(True) + dist.barrier() + else: + # Single-process path: build the full numpy matrix directly + # from the buffer list, popping each chunk after copy so the + # transient memory high-water mark stays ~= final matrix size + # (instead of 2× when going through torch.cat). + N = sum(t.shape[0] for t in self._offline_buffer) + D = self._offline_buffer[0].shape[1] + logger.info( + "[SidRqkmeans.on_train_end] fitting FAISS on " + "%d samples (D=%d)." % (N, D) + ) + full_np = np.empty((N, D), dtype=np.float32) + offset = 0 + # Pop from the front; each popped tensor is released before + # the next copy so cumulative torch memory shrinks monotonically. + while self._offline_buffer: + t = self._offline_buffer.pop(0) + n = t.shape[0] + # .float().numpy() returns a view sharing storage with + # the fp32 tensor; the subsequent assignment copies into + # full_np, after which ``t`` can be freed. + full_np[offset : offset + n] = t.float().numpy() + offset += n + del t + del self._offline_buffer + self._offline_buffer = [] + + # train_offline takes ownership of ``full_np`` (in-place + # residual updates); drop our reference after the call. + self._rqkmeans.train_offline(full_np, verbose=True) + del full_np diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py new file mode 100644 index 000000000..25b4f0800 --- /dev/null +++ b/tzrec/models/sid_rqkmeans_test.py @@ -0,0 +1,184 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from torchrec import KeyedTensor + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.models.sid_rqkmeans import SidRqkmeans +from tzrec.protos import model_pb2 +from tzrec.protos.models import sid_model_pb2 +from tzrec.utils.state_dict_util import init_parameters + + +def _make_batch(batch_size: int, input_dim: int) -> Batch: + """Create a minimal Batch with dense embedding features.""" + dense_feature = KeyedTensor.from_tensor_list( + keys=["item_emb"], tensors=[torch.randn(batch_size, input_dim)] + ) + return Batch( + dense_features={BASE_DATA_GROUP: dense_feature}, + sparse_features={}, + labels={}, + ) + + +class SidRqkmeansOfflineTest(unittest.TestCase): + """Tests for SidRqkmeans (FAISS-only).""" + + def _create_model(self, input_dim=32, n_layers=2, niter=5): + """Create a SidRqkmeans configured for offline FAISS fit.""" + from google.protobuf.struct_pb2 import Struct + + n_embed_str = ",".join(["16"] * n_layers) + + faiss_kwargs = Struct() + faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + + sid_rqkmeans_cfg = sid_model_pb2.SidRqkmeans( + input_dim=input_dim, + codebook=n_embed_str, + normalize_residuals=False, + faiss_kmeans_kwargs=faiss_kwargs, + embedding_feature_name="item_emb", + ) + feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name="deep", + feature_names=["item_emb"], + group_type=model_pb2.FeatureGroupType.DEEP, + ), + ] + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, + sid_rqkmeans=sid_rqkmeans_cfg, + ) + model = SidRqkmeans(model_config=model_config, features=[], labels=[]) + init_parameters(model, device=torch.device("cpu")) + return model + + def test_proto_parse(self) -> None: + """Verify faiss_kmeans_kwargs are parsed correctly.""" + model = self._create_model() + self.assertEqual(model._faiss_kwargs.get("niter"), 5) + self.assertEqual(model._faiss_kwargs.get("seed"), 1234) + self.assertFalse(model._faiss_kwargs.get("verbose")) + self.assertEqual(model._offline_buffer, []) + + def test_predict_collects_buffer(self) -> None: + """In train mode, predict should append to buffer; never fit.""" + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim) + model.train() + + for _ in range(4): + batch = _make_batch(B, input_dim) + preds = model.predict(batch) + self.assertIn("codes", preds) + + # Buffer accumulates 4 batches of B samples each + self.assertEqual(len(model._offline_buffer), 4) + total = sum(t.shape[0] for t in model._offline_buffer) + self.assertEqual(total, 4 * B) + # FAISS not yet triggered: layers should be uninitialized + for layer in model._rqkmeans.quantizer.layers: + self.assertFalse(layer.is_initialized) + + def test_on_train_end_runs_faiss(self) -> None: + """on_train_end triggers FAISS fit and clears buffer.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim) + model.train() + + # Accumulate enough samples (FAISS K-Means needs at least K points) + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + self.assertGreater(len(model._offline_buffer), 0) + + # Trigger one-shot FAISS fit + model.on_train_end() + + # Buffer should be cleared + self.assertEqual(model._offline_buffer, []) + # All layers should be initialized + centroids non-zero + for layer in model._rqkmeans.quantizer.layers: + self.assertTrue(bool(layer._is_initialized.item())) + self.assertGreater(layer.centroids.abs().sum().item(), 0.0) + + # After fit, predict on eval should produce valid codes + model.eval() + preds = model.predict(_make_batch(B, input_dim)) + codes = preds["codes"] + self.assertEqual(codes.shape, (B, 2)) + self.assertTrue((codes >= 0).all() and (codes < 16).all()) + + def test_on_train_end_noop_on_empty_buffer(self) -> None: + """on_train_end on an empty buffer is a warned no-op.""" + model = self._create_model() + model.on_train_end() # should not raise + + def test_post_fit_checkpoint_round_trips(self) -> None: + """Fit → save state_dict → load into fresh instance → predict. + + After loading, ``predict`` must return real (non-zero) codes — + the centroids and the ``_is_initialized`` flag both need to come + through the state_dict. + """ + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + src = self._create_model(input_dim=input_dim) + src.train() + for _ in range(8): + src.predict(_make_batch(B, input_dim)) + src.on_train_end() + sd = src.state_dict() + + dst = self._create_model(input_dim=input_dim) + dst.load_state_dict(sd) + dst.eval() + codes = dst.predict(_make_batch(B, input_dim))["codes"] + self.assertGreater( + codes.abs().sum().item(), + 0, + "post-fit checkpoint resume produced all-zero codes", + ) + + def test_mid_fit_checkpoint_rejected_on_load(self) -> None: + """Tampered state (_is_initialized=True + zero centroids) raises.""" + model = self._create_model() + sd = model.state_dict() + # Simulate a checkpoint that captured the flag mid-fit (before + # load_centroids_ ran): True flag, zero centroids. + layer0_prefix = next( + k.rsplit("._is_initialized", 1)[0] + for k in sd + if k.endswith("._is_initialized") + ) + sd[f"{layer0_prefix}._is_initialized"] = torch.tensor(True) + + fresh = self._create_model() + with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): + fresh.load_state_dict(sd) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py new file mode 100644 index 000000000..0520c8f23 --- /dev/null +++ b/tzrec/models/sid_rqvae.py @@ -0,0 +1,259 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SidRqvae: SID generation model using RQ-VAE (Encoder + VQ + Decoder). + +End-to-end differentiable training with reconstruction loss +and commitment loss. Optionally supports CLIP contrastive learning. +""" + +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +import torchmetrics + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.features.feature import BaseFeature +from tzrec.models._sid_helpers import parse_float_list, parse_int_list +from tzrec.models.model import BaseModel +from tzrec.modules.sid_generation import RQVAE +from tzrec.protos.model_pb2 import ModelConfig + + +class SidRqvae(BaseModel): + """SID generation model using RQ-VAE (Encoder + VQ + Decoder). + + End-to-end differentiable training with reconstruction loss + and commitment loss. Optionally supports CLIP contrastive learning. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. + """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + cfg = self._model_config # SidRqvae proto message + self._embedding_feature_name = cfg.embedding_feature_name + self._loss_type = cfg.loss_type + self._use_clip = cfg.HasField("clip_config") + self._clip_feature_name = ( + cfg.clip_config.clip_feature_name if self._use_clip else None + ) + self._is_clip_pair_feature_name = ( + cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None + ) + + hidden_dims = parse_int_list(cfg.hidden_dims) if cfg.hidden_dims else None + # Only forward latent_weight when proto sets it; otherwise let + # RQVAE / ResidualQuantized apply their signature default (1.0, 0.5). + rqvae_extra: Dict[str, Any] = {} + if cfg.latent_weight: + rqvae_extra["latent_weight"] = parse_float_list(cfg.latent_weight) + + assert cfg.codebook, "codebook must be set, e.g. '256,256,256'" + n_embed_list = parse_int_list(cfg.codebook) + n_layers = len(n_embed_list) + + use_sinkhorn = True + sinkhorn_iters = 5 + sinkhorn_epsilon = 10.0 + if cfg.HasField("sinkhorn_config"): + use_sinkhorn = cfg.sinkhorn_config.enabled + sinkhorn_iters = cfg.sinkhorn_config.iters + sinkhorn_epsilon = cfg.sinkhorn_config.epsilon + + self._rqvae = RQVAE( + input_dim=cfg.input_dim, + embed_dim=cfg.embed_dim, + hidden_dims=hidden_dims, + n_layers=n_layers, + n_embed=n_embed_list, + forward_mode=cfg.forward_mode, + normalize_residuals=cfg.normalize_residuals, + distance_type=cfg.distance_type, + commitment_loss=cfg.commitment_loss, + rotation_trick=cfg.rotation_trick, + kmeans_init=cfg.kmeans_init, + use_sinkhorn=use_sinkhorn, + sinkhorn_iters=sinkhorn_iters, + sinkhorn_epsilon=sinkhorn_epsilon, + loss_type=cfg.loss_type, + use_clip=self._use_clip, + **rqvae_extra, + ) + + def _extract_feature( + self, batch: Batch, feature_name: Optional[str] = None + ) -> torch.Tensor: + """Extract a named feature from Batch.dense_features. + + Args: + batch (Batch): input batch data. + feature_name (str, optional): feature name to extract. + Defaults to self._embedding_feature_name. + """ + if feature_name is None: + feature_name = self._embedding_feature_name + kt = batch.dense_features[BASE_DATA_GROUP] + return kt[feature_name] + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. + """ + embedding = self._extract_feature(batch) + + if self._use_clip: + return self._predict_mixed(embedding, batch) + else: + return self._predict_rqvae(embedding) + + def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: + """Standard RQ-VAE: encode -> quantize -> decode -> loss.""" + result = self._rqvae.forward_rqvae(embedding) + + predictions: Dict[str, torch.Tensor] = { + "codes": result["codes"], + } + + if self.is_train or self.is_eval: + predictions["quantized"] = result["quantized"] + predictions["x_hat"] = result["x_hat"] + predictions["reconstruction_loss"] = result["reconstruction_loss"] + predictions["quantization_loss"] = result["quantization_loss"] + + return predictions + + def _predict_mixed( + self, embedding: torch.Tensor, batch: Batch + ) -> Dict[str, torch.Tensor]: + """Mixed recon + CLIP: extract fea2 and clip_mask, call forward_mixed.""" + # Inference skips the dual path: fea2 / clip_mask aren't needed + # when we only emit codes. + if self._is_inference: + result = self._rqvae.forward_rqvae(embedding) + return {"codes": result["codes"]} + + fea2 = self._extract_feature(batch, self._clip_feature_name) + + is_clip_pair_raw = self._extract_feature(batch, self._is_clip_pair_feature_name) + clip_mask = is_clip_pair_raw.view(is_clip_pair_raw.shape[0], -1)[:, 0] > 0.5 + + result = self._rqvae.forward_mixed(embedding, fea2, clip_mask) + + predictions: Dict[str, torch.Tensor] = { + "codes": result["codes"], + "quantized": result["quantized"], + "x_hat": result["x_hat"], + "recon_loss": result["recon_loss"], + "clip_loss": result["clip_loss"], + "commitment_loss": result["commitment_loss"], + } + return predictions + + def init_loss(self) -> None: + """Initialize loss modules. + + Reconstruction loss and commitment loss are computed internally + by RQVAE and passed through predictions. No external loss module needed. + """ + pass + + def loss( + self, predictions: Dict[str, torch.Tensor], batch: Batch + ) -> Dict[str, torch.Tensor]: + """Compute loss of the model. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + + Return: + losses (dict): a dict of loss tensor. + """ + losses: Dict[str, torch.Tensor] = {} + if self._use_clip: + losses["recon_loss"] = predictions["recon_loss"] + losses["clip_loss"] = predictions["clip_loss"] + losses["commitment_loss"] = predictions["commitment_loss"] + else: + losses["reconstruction_loss"] = predictions["reconstruction_loss"] + losses["quantization_loss"] = predictions["quantization_loss"] + return losses + + def init_metric(self) -> None: + """Initialize metric modules.""" + self._metric_modules["mse"] = torchmetrics.MeanMetric() + self._metric_modules["unique_sid_ratio"] = torchmetrics.MeanMetric() + + # Loss values are already logged by the framework via loss(); only + # quantization quality needs the train-path metric. unique_sid_ratio + # is intentionally eval-only: torch.unique(codes, dim=0).shape[0] + # forces a GPU->host sync every step, and codebook coverage is a + # diagnostic, not a training signal. + self._train_metric_modules["mse"] = torchmetrics.MeanMetric() + + def update_train_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + ) -> None: + """Update train metric state. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + """ + if "x_hat" in predictions: + embedding = self._extract_feature(batch) + mse = F.mse_loss(predictions["x_hat"], embedding, reduction="mean") + self._train_metric_modules["mse"].update(mse) + + def update_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + losses: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: + """Update metric state. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + losses (dict, optional): a dict of loss. + """ + codes = predictions["codes"] + B = codes.shape[0] + + if "x_hat" in predictions: + embedding = self._extract_feature(batch) + mse = F.mse_loss(predictions["x_hat"], embedding, reduction="mean") + self._metric_modules["mse"].update(mse) + + unique_sids = torch.unique(codes, dim=0).shape[0] + self._metric_modules["unique_sid_ratio"].update(unique_sids / B) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py new file mode 100644 index 000000000..a87e58133 --- /dev/null +++ b/tzrec/models/sid_rqvae_test.py @@ -0,0 +1,405 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from torchrec import KeyedTensor + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.models.sid_rqvae import SidRqvae +from tzrec.protos import model_pb2 +from tzrec.protos.models import sid_model_pb2 +from tzrec.utils.state_dict_util import init_parameters + + +def _make_batch( + batch_size: int, + input_dim: int, + feature_name: str = "item_emb", + extra_features: dict = None, +) -> Batch: + """Create a minimal Batch with dense embedding features.""" + keys = [feature_name] + tensors = [torch.randn(batch_size, input_dim)] + if extra_features: + for k, v in extra_features.items(): + keys.append(k) + tensors.append(v) + dense_feature = KeyedTensor.from_tensor_list(keys=keys, tensors=tensors) + return Batch( + dense_features={BASE_DATA_GROUP: dense_feature}, + sparse_features={}, + labels={}, + ) + + +class SidRqvaeTest(unittest.TestCase): + """Tests for SidRqvae model.""" + + def _create_model(self, use_clip=False, input_dim=32, embed_dim=8, n_layers=2): + """Helper to create a SidRqvae model with minimal config.""" + n_embed_str = ",".join(["16"] * n_layers) + sid_rqvae_cfg = sid_model_pb2.SidRqvae( + input_dim=input_dim, + embed_dim=embed_dim, + codebook=n_embed_str, + forward_mode="ste", + loss_type="mse", + kmeans_init=False, + embedding_feature_name="item_emb", + ) + if use_clip: + sid_rqvae_cfg.clip_config.CopyFrom( + sid_model_pb2.ClipConfig( + clip_feature_name="image_emb", + is_clip_pair_feature_name="is_clip_pair", + ) + ) + + feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name="deep", + feature_names=["item_emb"], + group_type=model_pb2.FeatureGroupType.DEEP, + ), + ] + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, + sid_rqvae=sid_rqvae_cfg, + ) + model = SidRqvae(model_config=model_config, features=[], labels=[]) + init_parameters(model, device=torch.device("cpu")) + return model + + def test_rqvae_train_mode(self) -> None: + """Test SidRqvae in train mode: predict -> loss -> metric.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.train() + model.init_loss() + model.init_metric() + + batch = _make_batch(B, input_dim) + predictions = model.predict(batch) + + # Train mode should return all fields + self.assertIn("codes", predictions) + self.assertIn("quantized", predictions) + self.assertIn("x_hat", predictions) + self.assertIn("reconstruction_loss", predictions) + self.assertIn("quantization_loss", predictions) + self.assertEqual(predictions["codes"].shape[0], B) + + # Loss should return reconstruction_loss + quantization_loss + losses = model.loss(predictions, batch) + self.assertIn("reconstruction_loss", losses) + self.assertIn("quantization_loss", losses) + + # Total loss should be a scalar and have grad + total_loss = sum(losses.values()) + self.assertTrue(total_loss.requires_grad) + + # Metric update should not raise + model.update_metric(predictions, batch, losses) + metrics = model.compute_metric() + self.assertIn("mse", metrics) + self.assertIn("unique_sid_ratio", metrics) + + def test_rqvae_eval_mode(self) -> None: + """Test SidRqvae in eval mode: predict returns all fields.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.eval() + + batch = _make_batch(B, input_dim) + predictions = model.predict(batch) + + # Eval mode (not inference) should return all fields + self.assertIn("codes", predictions) + self.assertIn("quantized", predictions) + self.assertIn("x_hat", predictions) + self.assertIn("reconstruction_loss", predictions) + self.assertIn("quantization_loss", predictions) + + def test_rqvae_inference_mode(self) -> None: + """Test SidRqvae in inference mode: only codes returned.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.eval() + model.set_is_inference(True) + + batch = _make_batch(B, input_dim) + predictions = model.predict(batch) + + # Inference mode should only return codes + self.assertIn("codes", predictions) + self.assertNotIn("x_hat", predictions) + self.assertNotIn("reconstruction_loss", predictions) + + def test_rqvae_clip_mode(self) -> None: + """Test SidRqvae with CLIP mixed mode (mixed recon + clip batch).""" + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim, use_clip=True) + model.train() + model.init_loss() + + # Build mixed batch: first half recon, second half clip. + # With the explicit is_clip_pair column the actual tensor values + # no longer matter — the flag column drives routing. + item_emb = torch.randn(B, input_dim) + image_emb = torch.randn(B, input_dim) + is_clip_pair = torch.zeros(B, 1) + is_clip_pair[B // 2 :] = 1.0 # clip rows + + batch = Batch( + dense_features={ + BASE_DATA_GROUP: KeyedTensor.from_tensor_list( + keys=["item_emb", "image_emb", "is_clip_pair"], + tensors=[item_emb, image_emb, is_clip_pair], + ) + }, + sparse_features={}, + labels={}, + ) + + predictions = model.predict(batch) + + # Mixed mode should return recon_loss, clip_loss, commitment_loss + self.assertIn("codes", predictions) + self.assertIn("recon_loss", predictions) + self.assertIn("clip_loss", predictions) + self.assertIn("commitment_loss", predictions) + self.assertIn("x_hat", predictions) + self.assertEqual(predictions["codes"].shape[0], B) + + # Loss should return all three + losses = model.loss(predictions, batch) + self.assertIn("recon_loss", losses) + self.assertIn("clip_loss", losses) + self.assertIn("commitment_loss", losses) + + total_loss = sum(losses.values()) + self.assertTrue(total_loss.requires_grad) + + # Backward should work + total_loss.backward() + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 for p in model.parameters() + ) + self.assertTrue(has_grad) + + def test_rqvae_clip_all_recon(self) -> None: + """Test mixed mode with all-recon batch (edge case).""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim, use_clip=True) + model.train() + model.init_loss() + + # All recon: is_clip_pair = 0 everywhere + item_emb = torch.randn(B, input_dim) + image_emb = torch.randn(B, input_dim) + is_clip_pair = torch.zeros(B, 1) + + batch = Batch( + dense_features={ + BASE_DATA_GROUP: KeyedTensor.from_tensor_list( + keys=["item_emb", "image_emb", "is_clip_pair"], + tensors=[item_emb, image_emb, is_clip_pair], + ) + }, + sparse_features={}, + labels={}, + ) + + predictions = model.predict(batch) + model.loss(predictions, batch) + + # clip_loss should be 0 (no clip rows) + self.assertEqual(predictions["clip_loss"].item(), 0.0) + # recon_loss should be > 0 + self.assertGreater(predictions["recon_loss"].item(), 0.0) + + def test_rqvae_clip_all_clip(self) -> None: + """Test mixed mode with all-clip batch (edge case).""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim, use_clip=True) + model.train() + model.init_loss() + + # All clip: is_clip_pair = 1 everywhere + item_emb = torch.randn(B, input_dim) + image_emb = torch.randn(B, input_dim) + is_clip_pair = torch.ones(B, 1) + + batch = Batch( + dense_features={ + BASE_DATA_GROUP: KeyedTensor.from_tensor_list( + keys=["item_emb", "image_emb", "is_clip_pair"], + tensors=[item_emb, image_emb, is_clip_pair], + ) + }, + sparse_features={}, + labels={}, + ) + + predictions = model.predict(batch) + model.loss(predictions, batch) + + # recon_loss should be 0 (no recon rows) + self.assertEqual(predictions["recon_loss"].item(), 0.0) + # clip_loss should be > 0 + self.assertGreater(predictions["clip_loss"].item(), 0.0) + + def test_rqvae_backward(self) -> None: + """Test that backward pass works without errors.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.train() + model.init_loss() + + batch = _make_batch(B, input_dim) + predictions = model.predict(batch) + losses = model.loss(predictions, batch) + total_loss = sum(losses.values()) + total_loss.backward() + + # Encoder params should have gradients + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 for p in model.parameters() + ) + self.assertTrue(has_grad) + + def test_clip_mask_uses_flag_not_equality(self) -> None: + """The is_clip_pair flag, not bit-exact equality, drives routing. + + Build a batch where ``image_emb == item_emb`` numerically but + ``is_clip_pair=1``: row must route to the CLIP branch (under the + old bit-exact logic it would have been silently relabeled recon). + """ + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim, use_clip=True) + model.train() + model.init_loss() + + item_emb = torch.randn(B, input_dim) + image_emb = item_emb.clone() # bit-identical + is_clip_pair = torch.ones(B, 1) # but flagged as clip + + batch = Batch( + dense_features={ + BASE_DATA_GROUP: KeyedTensor.from_tensor_list( + keys=["item_emb", "image_emb", "is_clip_pair"], + tensors=[item_emb, image_emb, is_clip_pair], + ) + }, + sparse_features={}, + labels={}, + ) + + predictions = model.predict(batch) + # All rows flagged as clip -> recon_loss should be 0, clip_loss > 0 + self.assertEqual(predictions["recon_loss"].item(), 0.0) + self.assertGreater(predictions["clip_loss"].item(), 0.0) + + def test_commitment_loss_l1_branch(self) -> None: + """Verify the new commitment_loss='l1' branch runs end-to-end. + + Previously ``"l1"`` silently fell through to the L2 branch. + """ + from tzrec.modules.sid_generation.residual_quantized import ( + ResidualQuantized, + ) + + torch.manual_seed(0) + rq = ResidualQuantized( + embed_dim=8, + n_layers=2, + n_embed=4, + forward_mode="ste", + commitment_loss="l1", + kmeans_init=False, + use_sinkhorn=False, + ) + # Stub the codebook to known centroids so the result is reproducible. + for layer in rq.layers: + torch.nn.init.normal_(layer.embedding.weight, std=0.1) + + x = torch.randn(4, 8, requires_grad=True) + out = rq(x) + # Loss must be a finite scalar with gradient flowing back into x. + self.assertTrue(torch.isfinite(out.quantization_loss)) + out.quantization_loss.backward() + self.assertIsNotNone(x.grad) + + def test_sinkhorn_config_enabled_false(self) -> None: + """``sinkhorn_config { enabled: false }`` must turn Sinkhorn off. + + Previously ``use_sinkhorn`` was hard-coded ``True`` and the proto + block was honored only for iters/epsilon. + """ + n_embed_str = ",".join(["16"] * 2) + sid_rqvae_cfg = sid_model_pb2.SidRqvae( + input_dim=32, + embed_dim=8, + codebook=n_embed_str, + forward_mode="ste", + loss_type="mse", + kmeans_init=False, + embedding_feature_name="item_emb", + ) + sid_rqvae_cfg.sinkhorn_config.CopyFrom( + sid_model_pb2.SinkhornConfig(enabled=False) + ) + feature_groups = [ + model_pb2.FeatureGroupConfig( + group_name="deep", + feature_names=["item_emb"], + group_type=model_pb2.FeatureGroupType.DEEP, + ), + ] + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, + sid_rqvae=sid_rqvae_cfg, + ) + model = SidRqvae(model_config=model_config, features=[], labels=[]) + init_parameters(model, device=torch.device("cpu")) + + for layer in model._rqvae.quantizer.layers: + self.assertFalse(layer.use_sinkhorn) + + def test_sinkhorn_config_default_enabled(self) -> None: + """Omitting ``sinkhorn_config`` preserves on-by-default behavior. + + Back-compat for legacy configs that never set the sub-config. + """ + model = self._create_model() # no sinkhorn_config set + for layer in model._rqvae.quantizer.layers: + self.assertTrue(layer.use_sinkhorn) + + def test_commitment_loss_invalid_raises(self) -> None: + """ResidualQuantized rejects unknown commitment_loss spellings.""" + from tzrec.modules.sid_generation.residual_quantized import ( + ResidualQuantized, + ) + + with self.assertRaisesRegex(AssertionError, "commitment_loss"): + ResidualQuantized( + embed_dim=8, + n_layers=2, + n_embed=4, + commitment_loss="bogus", + use_sinkhorn=False, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py new file mode 100644 index 000000000..4d4a5d5f2 --- /dev/null +++ b/tzrec/modules/sid_generation/__init__.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tzrec.modules.sid_generation.clip_loss import ( + GatherLayer, +) +from tzrec.modules.sid_generation.kmeans import ( + KMeansLayer, +) +from tzrec.modules.sid_generation.residual_kmeans import ( + ResidualKMeans, + RQKMeans, +) +from tzrec.modules.sid_generation.residual_quantized import ( + ResidualQuantized, +) +from tzrec.modules.sid_generation.rqvae import ( + RQVAE, +) +from tzrec.modules.sid_generation.types import ( + QuantizeForwardMode, + QuantizeOutput, + ResidualQuantizedOutput, +) +from tzrec.modules.sid_generation.vector_quantize import ( + VectorQuantize, +) + +__all__ = [ + "QuantizeForwardMode", + "QuantizeOutput", + "ResidualQuantizedOutput", + "VectorQuantize", + "GatherLayer", + "ResidualQuantized", + "RQVAE", + "KMeansLayer", + "ResidualKMeans", + "RQKMeans", +] diff --git a/tzrec/modules/sid_generation/clip_loss.py b/tzrec/modules/sid_generation/clip_loss.py new file mode 100644 index 000000000..c3a020fd5 --- /dev/null +++ b/tzrec/modules/sid_generation/clip_loss.py @@ -0,0 +1,240 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CLIP contrastive learning loss with distributed all-gather support.""" + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F + + +class GatherLayer(torch.autograd.Function): + """Gather tensors from all workers with gradient support. + + Standard ``dist.all_gather`` detaches gradients; this custom + ``autograd.Function`` keeps the computation graph connected so + that contrastive losses can backpropagate through gathered tensors. + """ + + @staticmethod + def forward(ctx, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: + """All-gather ``x`` across ranks, returning one tensor per rank.""" + output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads: torch.Tensor) -> torch.Tensor: + """Sum-reduce the per-rank grads and return this rank's slice. + + ``all_reduce`` is sum, so reducing only this rank's slice gives + the same result as stacking + reducing + slicing, but avoids + materialising the full ``(world_size, B, D)`` buffer. + """ + grad_local = grads[dist.get_rank()].contiguous() + dist.all_reduce(grad_local) + return grad_local + + +def _all_gather_with_grad( + tensors: List[torch.Tensor], +) -> List[torch.Tensor]: + """All-gather tensors across distributed workers with gradient support. + + In single-process mode, returns input tensors unchanged. + In multi-process mode, uses GatherLayer for backward-compatible + all_gather. + + Args: + tensors (List[Tensor]): list of tensors to gather. + + Returns: + List[Tensor]: gathered tensors, each (world_size * B, ...). + """ + if not dist.is_initialized() or dist.get_world_size() == 1: + return tensors + + gathered: List[torch.Tensor] = [] + for tensor in tensors: + tensor_all = GatherLayer.apply(tensor) + gathered.append(torch.cat(tensor_all, dim=0)) + return gathered + + +class MaskedCLIPLoss(nn.Module): + """Masked CLIP loss for mixed recon+clip batches. + + In a mixed batch, recon rows (clip_mask=False) should not + contribute to CLIP loss, and recon columns should not serve as + negatives. This module applies row and column masks to achieve + selective contrastive learning without data-dependent branching, + ensuring ``torch.compile`` compatibility. + + Input dict keys: + 'image_embed': (B, D) quantized output of first feature + 'text_embed': (B, D) quantized output of second feature + 'image_embed_ori': (B, D) original embedding of first feature + 'text_embed_ori': (B, D) original embedding of second feature + 'logit_scale_self': scalar self-contrast temperature + 'logit_scale_cl': scalar cross-modal contrast temperature + 'logit_scale': scalar original feature contrast temperature + + Output dict keys: + 'clip_loss': scalar mean of three losses (self/ori/cl) + 'clip_acc': scalar contrast accuracy (%); 0 during training + 'loss_self': scalar quantized vs quantized + 'loss_ori': scalar quantized vs original + 'loss_cl': scalar quantized vs counterpart original + """ + + def __init__(self) -> None: + super().__init__() + self.labels: Optional[torch.Tensor] = None + self.last_local_batch_size: Optional[int] = None + self._rank = dist.get_rank() if dist.is_initialized() else 0 + + @staticmethod + def _gather_bool_mask(mask: torch.Tensor) -> torch.Tensor: + """All-gather bool mask across distributed workers.""" + if not dist.is_initialized() or dist.get_world_size() == 1: + return mask + mask_list = [torch.zeros_like(mask) for _ in range(dist.get_world_size())] + dist.all_gather(mask_list, mask) + return torch.cat(mask_list, dim=0) + + def _masked_cross_entropy( + self, + logits_i: torch.Tensor, + logits_t: torch.Tensor, + safe_labels: torch.Tensor, + clip_mask: torch.Tensor, + ) -> torch.Tensor: + """Masked cross-entropy on column-masked logits, row-masked average. + + Args: + logits_i: (B, B_global) column-masked logits (image branch). + logits_t: (B, B_global) column-masked logits (text branch). + safe_labels: (B,) labels with recon rows fallback to safe col. + clip_mask: (B,) bool, True = clip row. + """ + ce_i = F.cross_entropy(logits_i, safe_labels, reduction="none") + ce_t = F.cross_entropy(logits_t, safe_labels, reduction="none") + # NaN can occur when all logits are -inf (all-recon edge case) + ce_i = torch.nan_to_num(ce_i, nan=0.0) + ce_t = torch.nan_to_num(ce_t, nan=0.0) + + n_valid = clip_mask.float().sum().clamp(min=1) + return ((ce_i + ce_t) * clip_mask.float()).sum() / (2 * n_valid) + + def forward( + self, + outputs: Dict[str, torch.Tensor], + clip_mask: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward with mask. + + Args: + outputs: feature dict, see class docstring. + clip_mask: (B,) bool, True = clip sample. + """ + image_embed = outputs["image_embed"] + text_embed = outputs["text_embed"] + image_embed_ori = outputs["image_embed_ori"] + text_embed_ori = outputs["text_embed_ori"] + logit_scale = outputs["logit_scale"] + logit_scale_self = outputs["logit_scale_self"] + logit_scale_cl = outputs["logit_scale_cl"] + + local_batch_size = image_embed.size(0) + + # Update labels when batch size changes (multi-GPU offset) + if local_batch_size != self.last_local_batch_size: + self.labels = local_batch_size * self._rank + torch.arange( + local_batch_size, device=image_embed.device + ) + self.last_local_batch_size = local_batch_size + + # L2 normalize quantized features + image_embed = F.normalize(image_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + + # All-gather across GPUs (with gradient support) + image_embed_all, text_embed_all = _all_gather_with_grad( + [image_embed, text_embed] + ) + image_embed_all_ori, text_embed_all_ori = _all_gather_with_grad( + [image_embed_ori, text_embed_ori] + ) + + # --- Compute six groups of logits (image/text × self/ori/cl) --- + logits_img_self = logit_scale_self * image_embed @ text_embed_all.t() + logits_txt_self = logit_scale_self * text_embed @ image_embed_all.t() + + logits_img_ori = logit_scale * image_embed @ text_embed_all_ori.t() + logits_txt_ori = logit_scale * text_embed @ image_embed_all_ori.t() + + logits_img_cl = logit_scale_cl * image_embed @ image_embed_all_ori.t() + logits_txt_cl = logit_scale_cl * text_embed @ text_embed_all_ori.t() + + # --- Column mask: recon columns -> -inf (not as negatives) --- + clip_mask_all = self._gather_bool_mask(clip_mask) + col_mask = (~clip_mask_all).unsqueeze(0) # (1, B_global) + + logits_img_self = logits_img_self.masked_fill(col_mask, float("-inf")) + logits_txt_self = logits_txt_self.masked_fill(col_mask, float("-inf")) + logits_img_ori = logits_img_ori.masked_fill(col_mask, float("-inf")) + logits_txt_ori = logits_txt_ori.masked_fill(col_mask, float("-inf")) + logits_img_cl = logits_img_cl.masked_fill(col_mask, float("-inf")) + logits_txt_cl = logits_txt_cl.masked_fill(col_mask, float("-inf")) + + # --- Safe labels: recon rows fallback to first clip column --- + labels = self.labels + fallback = clip_mask.long().argmax() # first clip sample index + safe_labels = torch.where(clip_mask, labels, fallback.expand_as(labels)) + + # --- Masked CE for three loss groups --- + loss_self = self._masked_cross_entropy( + logits_img_self, logits_txt_self, safe_labels, clip_mask + ) + loss_ori = self._masked_cross_entropy( + logits_img_ori, logits_txt_ori, safe_labels, clip_mask + ) + loss_cl = self._masked_cross_entropy( + logits_img_cl, logits_txt_cl, safe_labels, clip_mask + ) + + clip_loss = (loss_self + loss_ori + loss_cl) / 3 + + # Retrieval accuracy is diagnostic-only; skip the four argmax+eq+sum + # reductions during training (recover via the eval pass). + if self.training: + acc = torch.zeros((), device=clip_loss.device) + else: + with torch.no_grad(): + n_valid = clip_mask.float().sum().clamp(min=1) + correct = ( + (logits_img_self.argmax(-1).eq(safe_labels) & clip_mask).sum() + + (logits_txt_self.argmax(-1).eq(safe_labels) & clip_mask).sum() + + (logits_img_ori.argmax(-1).eq(safe_labels) & clip_mask).sum() + + (logits_txt_ori.argmax(-1).eq(safe_labels) & clip_mask).sum() + ) + acc = 100 * correct / (n_valid * 4) + + return { + "clip_loss": clip_loss, + "clip_acc": acc, + "loss_self": loss_self, + "loss_ori": loss_ori, + "loss_cl": loss_cl, + } diff --git a/tzrec/modules/sid_generation/kmeans.py b/tzrec/modules/sid_generation/kmeans.py new file mode 100644 index 000000000..3751d45f7 --- /dev/null +++ b/tzrec/modules/sid_generation/kmeans.py @@ -0,0 +1,290 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""K-Means utilities for the SID-generation stack. + +This module is the single home for torch-native K-Means code used by +SID models: + +* :class:`KMeansLayer` — per-layer centroid container used by + :class:`ResidualKMeans` / :class:`RQKMeans`. Centroids are injected + by the FAISS backend via ``load_centroids_``; the only forward path + is ``predict``. +* :func:`_kmeans` / :func:`_residual_kmeans` — pure-torch Lloyd's + K-Means + residual variant, used by :class:`ResidualQuantized` to + warm-start the RQ-VAE codebook on the first training batch. They run + once on a single batch of encoder outputs (typically ~2k × 64), so + pulling in FAISS here would be all overhead and no benefit. +""" + +from typing import List, Tuple + +import torch +from torch import nn + + +def recon_diagnostics( + x: torch.Tensor, + out: torch.Tensor, + epsilon: float = 1e-4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """MSE + relative-L1 reconstruction diagnostics. + + Shared by :meth:`SidRqkmeans.update_metric` (which wants tensors for + ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeans.train_offline`'s + per-layer log line (which converts to Python floats via ``.item()``). + + Args: + x: ground-truth embedding, shape (B, D). + out: quantized reconstruction, shape (B, D). + epsilon: numerical stabilizer for the relative-L1 denominator. + + Returns: + mse: scalar ``((out - x) ** 2).mean()``. + rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. + """ + mse = ((out - x) ** 2).mean() + rel = ( + torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) + ).mean() + return mse, rel + + +@torch.no_grad() +def _squared_euclidean_distance( + x: torch.Tensor, + y: torch.Tensor, + chunk_size: int = 50000, +) -> torch.Tensor: + """Squared L2 distance with chunked computation for memory efficiency. + + Chunks the rows of ``x`` so peak memory is bounded by + ``chunk_size * K * 4 bytes`` (fp32) regardless of ``N``. + + Args: + x (Tensor): data points, shape (N, D). + y (Tensor): centroids, shape (K, D). + chunk_size (int): max rows of x per chunk. Default: 50000. + + Returns: + Tensor: squared distances, shape (N, K). + """ + x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) + y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) + N = x.shape[0] + if N <= chunk_size: + return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) + out = x.new_empty(N, y.shape[0]) + for start in range(0, N, chunk_size): + end = min(start + chunk_size, N) + out[start:end] = (x_sq[start:end] + y_sq - 2.0 * x[start:end] @ y.t()).clamp_( + min=0.0 + ) + return out + + +@torch.no_grad() +def _kmeans_plus_plus( + data: torch.Tensor, + n_clusters: int, +) -> torch.Tensor: + """KMeans++ initialization (Arthur & Vassilvitskii 2007). + + Selects initial centroids via distance-weighted probability sampling + to ensure well-spread starting points. Used by the RQ-VAE codebook + init path (``ResidualQuantized.kmeans_init``); RQKMeans itself no + longer needs it. + + Args: + data (Tensor): data points, shape (N, D). + n_clusters (int): number of clusters K. + + Returns: + Tensor: initial centroids, shape (K, D). + """ + N, D = data.shape + centroids = torch.zeros(n_clusters, D, device=data.device, dtype=data.dtype) + + idx = torch.randint(0, N, (1,), device=data.device) + centroids[0] = data[idx] + + for i in range(1, n_clusters): + dists = _squared_euclidean_distance(data, centroids[:i]) # (N, i) + min_dists = dists.min(dim=1)[0] # (N,) + if min_dists.sum() == 0: + centroids[i:] = data[ + torch.randint(0, N, (n_clusters - i,), device=data.device) + ] + break + next_idx = torch.multinomial(min_dists, num_samples=1) + centroids[i] = data[next_idx] + + return centroids + + +@torch.no_grad() +def _kmeans( + samples: torch.Tensor, + n_clusters: int, + n_iters: int = 100, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Lloyd's K-Means with KMeans++ initialization. + + Used by :class:`ResidualQuantized.init_embed_` to warm-start the + RQ-VAE codebook on the first training batch. + + Args: + samples (Tensor): data points, shape (N, D). + n_clusters (int): number of clusters K. + n_iters (int): number of Lloyd iterations. Default: 100. + + Returns: + centroids (Tensor): cluster centers, shape (K, D). + assignments (Tensor): cluster indices, shape (N,). + """ + N, D = samples.shape + centroids = _kmeans_plus_plus(samples, n_clusters) + + for _ in range(n_iters): + dists = _squared_euclidean_distance(samples, centroids) # (N, K) + assignments = dists.argmin(dim=-1) # (N,) + + bins = torch.bincount(assignments, minlength=n_clusters) + zero_mask = bins == 0 + bins_clamped = bins.masked_fill(zero_mask, 1) + + new_centroids = torch.zeros_like(centroids) + new_centroids.scatter_add_(0, assignments.unsqueeze(1).expand(-1, D), samples) + new_centroids = new_centroids / bins_clamped.unsqueeze(1) + + # Keep old centroids for empty clusters + centroids = torch.where(zero_mask.unsqueeze(1), centroids, new_centroids) + + return centroids, assignments + + +@torch.no_grad() +def _residual_kmeans( + samples: torch.Tensor, + n_clusters_list: List[int], + n_iters: int = 100, +) -> List[torch.Tensor]: + """Residual K-Means: per-layer cluster then subtract centroids. + + Used by :class:`ResidualQuantized.init_embed_` to seed every RQ + codebook layer in one pass over the first training batch. + + Args: + samples (Tensor): data points, shape (N, D). + n_clusters_list (List[int]): per-layer cluster counts. + n_iters (int): K-Means iterations per layer. Default: 100. + + Returns: + List[Tensor]: per-layer centroids ``[(K0, D), (K1, D), ...]``. + """ + res_centers = [] + for n_clusters in n_clusters_list: + centroids, assignments = _kmeans(samples, n_clusters, n_iters) + res_centers.append(centroids) + samples = samples - centroids[assignments] + return res_centers + + +class KMeansLayer(nn.Module): + """Single layer of a residual K-Means stack. + + Centroids are populated externally by ``load_centroids_`` (called per + layer by the FAISS backend in :class:`ResidualKMeans`); ``predict`` + is the only forward path. PyTorch state-dict keys are scoped by + attribute path (``layers..centroids``), so renaming the class + does not break existing checkpoints. + + Args: + n_clusters (int): number of clusters (codebook size). + n_features (int): feature dimension. + """ + + def __init__( + self, + n_clusters: int, + n_features: int, + ) -> None: + super().__init__() + self.n_clusters = n_clusters + self.n_features = n_features + + self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) + # Flipped by ``load_centroids_`` after the FAISS fit. Persistent + # so a normal post-fit checkpoint round-trips; mid-fit poisoning + # (True flag + still-zero centroids) is caught in _load_from_state_dict. + self.register_buffer("_is_initialized", torch.tensor(False)) + + @property + def is_initialized(self) -> bool: + """Whether centroids have been injected via ``load_centroids_``.""" + return self._is_initialized.item() + + @torch.no_grad() + def load_centroids_(self, centroids: torch.Tensor) -> None: + """Inject offline-trained centroids. + + Args: + centroids (Tensor): externally trained centroids, + shape (n_clusters, n_features). + """ + assert centroids.shape == self.centroids.shape, ( + f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " + f"got {tuple(centroids.shape)}" + ) + self.centroids.copy_( + centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) + ) + self._is_initialized.fill_(True) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + """Reject mid-fit-checkpoint state dicts (True flag + zero centroids).""" + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + if bool(self._is_initialized.item()) and self.centroids.abs().sum() == 0: + error_msgs.append( + f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " + "are all zero — checkpoint was likely taken mid-FAISS-fit. " + "Re-run on_train_end to produce a valid checkpoint." + ) + + @torch.no_grad() + def predict(self, batch: torch.Tensor) -> torch.Tensor: + """Assign points to nearest centroid. + + Args: + batch (Tensor): data points, shape (B, D). + + Returns: + Tensor: cluster indices, shape (B,). + """ + dists = _squared_euclidean_distance(batch, self.centroids) + return torch.argmin(dists, dim=-1) diff --git a/tzrec/modules/sid_generation/residual_kmeans.py b/tzrec/modules/sid_generation/residual_kmeans.py new file mode 100644 index 000000000..60331030d --- /dev/null +++ b/tzrec/modules/sid_generation/residual_kmeans.py @@ -0,0 +1,373 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-layer residual K-Means: ResidualKMeans and RQKMeans wrapper. + +Training is FAISS-only: the codebook is built once via ``train_offline`` +over the full embedding matrix; ``forward`` is read-only (predict + lookup). +""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from tzrec.modules.sid_generation.kmeans import KMeansLayer, recon_diagnostics +from tzrec.utils.logging_util import logger + + +class ResidualKMeans(nn.Module): + """Multi-layer residual K-Means with offline FAISS training. + + Each layer quantizes the residual from the previous layer: + residual_0 = input + for each layer i: + (optionally) residual_i = L2_normalize(residual_i) + code_i = layer_i.predict(residual_i) + quantized_i = layer_i.centroids[code_i] + residual_{i+1} = residual_i - quantized_i + output = sum of all quantized_i + + Semantic ID = (code_0, code_1, ..., code_{n_layers-1}) + + Args: + embed_dim (int): feature dimension. + n_layers (int): number of residual quantization layers. + n_embed (int|List[int]): number of clusters per layer. Default: 256. + All layers must share the same ``K`` — a single FAISS ``Kmeans`` + object is reused across layers (matches the OneRec reference). + Non-uniform codebooks are not supported. + normalize_residuals (bool): whether to L2-normalize residuals + before each layer. Default: False. + faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to + ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, + 'gpu': True, 'verbose': True, 'spherical': False}). + """ + + def __init__( + self, + embed_dim: int, + n_layers: int, + n_embed: Union[int, List[int]] = 256, + normalize_residuals: bool = False, + faiss_kmeans_kwargs: Optional[Dict] = None, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.n_layers = n_layers + self.normalize_residuals = normalize_residuals + self.faiss_kmeans_kwargs = dict(faiss_kmeans_kwargs or {}) + + if isinstance(n_embed, int): + n_embed_list = [n_embed] * n_layers + else: + assert len(n_embed) == n_layers, ( + "length of n_embed and n_layers must be same, " + f"but got {len(n_embed)} vs {n_layers}" + ) + n_embed_list = list(n_embed) + # ``train_offline`` reuses a single ``faiss.Kmeans`` instance across + # layers, so non-uniform codebooks would silently train layers 1+ + # with ``K=n_embed_list[0]``. Fail fast instead. + assert len(set(n_embed_list)) == 1, ( + "ResidualKMeans / RQKMeans require a uniform codebook size " + f"across layers; got {n_embed_list}." + ) + self.n_embed_list = n_embed_list + + self.layers = nn.ModuleList( + [ + KMeansLayer( + n_clusters=n_embed_list[i], + n_features=embed_dim, + ) + for i in range(n_layers) + ] + ) + + @property + def all_initialized(self) -> bool: + """Whether all layers have been initialized via offline FAISS.""" + return all(layer.is_initialized for layer in self.layers) + + def output_dim(self) -> int: + """Output dimension of the module.""" + return self.embed_dim + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Assign codes per layer and sum the centroids. + + Codebook is read-only here; training happens in ``train_offline``. + Uninitialized layers return dummy zeros so the model is callable + before the one-shot FAISS fit completes. + + Args: + input (Tensor): input embeddings, shape (B, D). + + Returns: + codes (Tensor): cluster indices per layer, shape (B, n_layers). + quantized (Tensor): sum of quantized embeddings, shape (B, D). + """ + residual = input + all_codes: List[torch.Tensor] = [] + quantized_sum = torch.zeros_like(input) + + for layer in self.layers: + if self.normalize_residuals: + residual = F.normalize(residual, dim=-1) + + if layer.is_initialized: + codes = layer.predict(residual) + quantized = layer.centroids[codes] + residual = residual - quantized + quantized_sum = quantized_sum + quantized + else: + codes = torch.zeros( + input.shape[0], dtype=torch.long, device=input.device + ) + all_codes.append(codes) + + cluster_ids = torch.stack(all_codes, dim=-1) # (B, n_layers) + return cluster_ids, quantized_sum + + @torch.no_grad() + def get_codes(self, input: torch.Tensor) -> torch.Tensor: + """Assign semantic IDs without updating centroids.""" + residual = input + all_codes: List[torch.Tensor] = [] + + for layer in self.layers: + if self.normalize_residuals: + residual = F.normalize(residual, dim=-1) + + codes = layer.predict(residual) + all_codes.append(codes) + quantized = layer.centroids[codes] + residual = residual - quantized + + return torch.stack(all_codes, dim=-1) + + @torch.no_grad() + def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: + """Get centroid weights for a specific layer. + + Args: + layer_idx (int): index of the quantization layer. + + Returns: + Tensor: centroids, shape (n_embed, embed_dim). + """ + return self.layers[layer_idx].centroids + + @torch.no_grad() + def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: + """Reconstruct embeddings from semantic ID codes. + + Args: + codes (Tensor): cluster ids, shape (B, n_layers). + + Returns: + Tensor: reconstructed embeddings, shape (B, D). + """ + quantized_sum = torch.zeros( + codes.shape[0], + self.embed_dim, + device=codes.device, + dtype=torch.float, + ) + for i, layer in enumerate(self.layers): + emb = layer.centroids[codes[:, i]] + quantized_sum = quantized_sum + emb + return quantized_sum + + @torch.no_grad() + def train_offline( + self, + inputs: Union[torch.Tensor, "np.ndarray"], + verbose: bool = True, + ) -> None: + """Train the multi-layer codebook via offline FAISS K-Means. + + Args: + inputs: full embedding matrix, shape (N, D). Either a + ``torch.Tensor`` (will be copied to numpy) or a + ``np.ndarray`` (ownership transferred; caller MUST + release any outside reference — the array is mutated + in-place to compute residuals layer by layer). + verbose (bool): whether to print per-layer reconstruction + loss. Default: True. + + Raises: + ImportError: if ``faiss`` is not installed. + """ + try: + import faiss + except ImportError as e: + raise ImportError( + "faiss is required for RQKMeans training. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + # Materialise to a float32 contiguous numpy array that we own + # (so in-place residual updates are safe). + if isinstance(inputs, torch.Tensor): + assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + # Tensor path still requires a copy; caller will hold a + # reference until we return, so we must not alias it. + x = inputs.detach().cpu().float().numpy().copy() + else: + assert inputs.ndim == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + # Numpy path: take ownership — no extra copy. Caller promises + # the array is no longer used outside. Only ensure dtype + # + contiguity (zero-copy when already satisfied). + x = np.ascontiguousarray(inputs, dtype=np.float32) + N, D = x.shape + out = np.zeros((N, D), dtype=np.float32) + + # Reuse one Kmeans instance across all layers (matches OneRec impl): + # rebuilding the FAISS object per layer doubles index-init cost. + n_embed = self.n_embed_list[0] + kmeans = faiss.Kmeans(self.embed_dim, n_embed, **self.faiss_kmeans_kwargs) + + # Chunk size for index.search to limit peak memory. + # 500K × 512 × 4B ≈ 1 GB per chunk. + SEARCH_CHUNK = 500_000 + + for layer_idx in range(self.n_layers): + if self.normalize_residuals: + norms = np.linalg.norm(x, axis=1, keepdims=True) + np.maximum(norms, 1e-8, out=norms) + x /= norms # in-place + + kmeans.train(x) + + for start in range(0, N, SEARCH_CHUNK): + end = min(start + SEARCH_CHUNK, N) + _, idx = kmeans.index.search(x[start:end], 1) + q = kmeans.centroids[idx.ravel()] # (chunk, D) + out[start:end] += q + x[start:end] -= q # residual + del idx, q + + if verbose: + out_t = torch.from_numpy(out) + ref_t = torch.from_numpy(out + x) # x_in = out + residual + logger.info( + "[ResidualKMeans][offline_faiss][layer %d] %s", + layer_idx, + self._calc_loss(ref_t, out_t), + ) + del out_t, ref_t + + centroids_t = torch.from_numpy(kmeans.centroids.copy()) + self.layers[layer_idx].load_centroids_(centroids_t) + if verbose: + logger.info( + "[ResidualKMeans][offline_faiss] layer %d finished", + layer_idx, + ) + + @staticmethod + def _calc_loss( + x: torch.Tensor, out: torch.Tensor, epsilon: float = 1e-4 + ) -> Dict[str, float]: + """Reconstruction loss diagnostics (MSE + relative L1).""" + loss, rel_loss = recon_diagnostics(x, out, epsilon=epsilon) + return {"loss": float(loss.item()), "rel_loss": float(rel_loss.item())} + + +class RQKMeans(nn.Module): + """RQ-KMeans: multi-layer residual K-Means trained offline via FAISS. + + No Encoder/Decoder — directly clusters input vectors via residual + K-Means. Codebook is built once by :meth:`train_offline`; ``forward`` + is read-only (assign + lookup). + + Args: + embed_dim (int): feature dimension. Default: 64. + n_layers (int): number of residual quantization layers. Default: 3. + n_embed (int|List[int]): number of clusters per layer. Default: 256. + normalize_residuals (bool): L2-normalize residuals before each + layer. Default: False. + faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to + ``faiss.Kmeans(...)``. + """ + + def __init__( + self, + embed_dim: int = 64, + n_layers: int = 3, + n_embed: Union[int, List[int]] = 256, + normalize_residuals: bool = False, + faiss_kmeans_kwargs: Optional[Dict] = None, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.quantizer = ResidualKMeans( + embed_dim=embed_dim, + n_layers=n_layers, + n_embed=n_embed, + normalize_residuals=normalize_residuals, + faiss_kmeans_kwargs=faiss_kmeans_kwargs, + ) + + def train_offline( + self, + inputs: Union[torch.Tensor, "np.ndarray"], + verbose: bool = True, + ) -> None: + """Build codebook offline via FAISS. + + Args: + inputs: full embedding matrix, shape (N, embed_dim). Either + a ``torch.Tensor`` or an ``np.ndarray`` (ownership + transferred — array is mutated in-place). + verbose (bool): print per-layer reconstruction loss. + """ + self.quantizer.train_offline(inputs, verbose=verbose) + + def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + """Forward: residual K-Means assignment (no training). + + Args: + x: (B, embed_dim) input features. + + Returns: + dict with keys: + 'codes': (B, n_layers) semantic IDs. + 'quantized': (B, embed_dim) quantized vector (sum of centroids). + """ + codes, quantized = self.quantizer(x) + return { + "codes": codes, + "quantized": quantized, + } + + @torch.no_grad() + def get_codes(self, x: torch.Tensor) -> torch.Tensor: + """Inference: get semantic IDs.""" + return self.quantizer.get_codes(x) + + @torch.no_grad() + def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: + """Reconstruct vectors from semantic IDs (centroid lookup + sum).""" + return self.quantizer.decode_codes(codes) + + @torch.no_grad() + def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: + """Get centroid weights for a specific layer.""" + return self.quantizer.get_codebook_embeddings(layer_idx) diff --git a/tzrec/modules/sid_generation/residual_quantized.py b/tzrec/modules/sid_generation/residual_quantized.py new file mode 100644 index 000000000..92a22250b --- /dev/null +++ b/tzrec/modules/sid_generation/residual_quantized.py @@ -0,0 +1,399 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ResidualQuantized: multi-layer residual vector quantization with VQ layers.""" + +from typing import List, Sequence, Union + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F + +from tzrec.modules.sid_generation.kmeans import _residual_kmeans +from tzrec.modules.sid_generation.types import ( + QuantizeForwardMode, + ResidualQuantizedOutput, +) +from tzrec.modules.sid_generation.vector_quantize import VectorQuantize +from tzrec.utils.logging_util import logger + + +class ResidualQuantized(nn.Module): + """Multi-layer residual vector quantization. + + Each layer quantizes the residual from the previous layer: + residual_0 = input + for each layer i: + (optionally) residual_i = L2_normalize(residual_i) + code_i, quantized_i = quantize(residual_i) + residual_{i+1} = residual_i - quantized_i + output = sum of all quantized_i + + Semantic ID = (code_0, code_1, ..., code_{n_layers-1}) + + Args: + embed_dim (int): dimension of input embeddings. + n_layers (int): number of quantization layers. + n_embed (int|List[int]): codebook size per layer. Default: 256. + forward_mode (str): VQ forward mode ('ste'|'gumbel_softmax'). + Default: 'ste'. + normalize_residuals (bool): L2-normalize residuals before each + quantization layer. Default: False. + distance_type (str|List[str]): distance metric per layer, + 'l2' or 'cosine'. Supports per-layer list. Default: 'l2'. + commitment_loss (str): commitment loss type, 'l2', 'l1' or 'cos'. + Default: 'l2'. + latent_weight (List[float]): commitment loss weights [w1, w2]. + w1: x toward quant (encoder side). + w2: quant toward x (codebook side). + Default: [1.0, 0.5]. + rotation_trick (bool): use rotation trick for improved STE + gradient estimation (arXiv:2410.06424). Default: False. + kmeans_init (bool): use residual K-Means codebook initialization + on first forward. Default: False. + use_sinkhorn (bool): Sinkhorn uniform assignment. Default: True. + sinkhorn_iters (int): Sinkhorn iterations. Default: 5. + sinkhorn_epsilon (float): Sinkhorn sharpness. Default: 10.0. + """ + + _FORWARD_MODE_MAP = { + "gumbel_softmax": QuantizeForwardMode.GUMBEL_SOFTMAX, + "ste": QuantizeForwardMode.STE, + } + + def __init__( + self, + embed_dim: int, + n_layers: int, + n_embed: Union[int, List[int]] = 256, + forward_mode: str = "ste", + normalize_residuals: bool = False, + distance_type: Union[str, List[str]] = "l2", + commitment_loss: str = "l2", + latent_weight: Sequence[float] = (1.0, 0.5), + rotation_trick: bool = False, + kmeans_init: bool = False, + use_sinkhorn: bool = True, + sinkhorn_iters: int = 5, + sinkhorn_epsilon: float = 10.0, + ) -> None: + super().__init__() + assert commitment_loss in ("l2", "l1", "cos"), ( + f"commitment_loss must be 'l2', 'l1' or 'cos', got {commitment_loss!r}" + ) + self.embed_dim = embed_dim + self.n_layers = n_layers + self.normalize_residuals = normalize_residuals + self.commitment_loss_type = commitment_loss + self.rotation_trick = rotation_trick + + self.commitment_w1, self.commitment_w2 = latent_weight + + # ``initted`` is the kmeans_init guard: True means "codebook has + # been seeded", so init_embed_() becomes a no-op on later forwards. + self.register_buffer("initted", torch.tensor([not kmeans_init])) + + if forward_mode not in self._FORWARD_MODE_MAP: + raise ValueError( + f"Unsupported forward_mode '{forward_mode}', " + f"choose from {list(self._FORWARD_MODE_MAP.keys())}" + ) + mode_enum = self._FORWARD_MODE_MAP[forward_mode] + + if isinstance(n_embed, int): + n_embed_list = [n_embed] * n_layers + else: + assert len(n_embed) == n_layers, ( + "length of n_embed and n_layers must be same, " + f"but got {len(n_embed)} vs {n_layers}" + ) + n_embed_list = list(n_embed) + self.n_embed_list = n_embed_list + + if isinstance(distance_type, str): + distance_types = [distance_type] * n_layers + else: + assert len(distance_type) == n_layers, ( + "length of distance_type and n_layers must be same, " + f"but got {len(distance_type)} vs {n_layers}" + ) + distance_types = list(distance_type) + + self.layers = nn.ModuleList( + [ + VectorQuantize( + embed_dim=embed_dim, + n_embed=n_embed_list[i], + forward_mode=mode_enum, + distance_type=distance_types[i], + use_sinkhorn=use_sinkhorn, + sinkhorn_iters=sinkhorn_iters, + sinkhorn_epsilon=sinkhorn_epsilon, + ) + for i in range(n_layers) + ] + ) + + logger.info( + "ResidualQuantized init: embed_dim=%d, n_layers=%d, " + "n_embed=%s, forward_mode=%s, normalize_residuals=%s, " + "distance_type=%s, commitment_loss=%s, latent_weight=%s, " + "rotation_trick=%s, kmeans_init=%s, use_sinkhorn=%s, " + "sinkhorn_iters=%d, sinkhorn_epsilon=%s", + embed_dim, + n_layers, + n_embed, + forward_mode, + normalize_residuals, + distance_type, + commitment_loss, + list(latent_weight), + rotation_trick, + kmeans_init, + use_sinkhorn, + sinkhorn_iters, + sinkhorn_epsilon, + ) + + @torch.jit.ignore + @torch.no_grad() + def init_embed_(self, data: torch.Tensor) -> None: + """Initialize codebook weights via residual K-Means. + + Only executed once when kmeans_init=True and not yet initialized. + Uses the first batch of training data as initialization pool. + + Args: + data (Tensor): input data, shape (B, D). + """ + if self.initted: + return + + centers = _residual_kmeans(data, self.n_embed_list) + + # Average per-layer centroids across DDP ranks so every rank + # starts from the same codebook. + if dist.is_initialized() and dist.get_world_size() > 1: + for c in centers: + dist.all_reduce(c, op=dist.ReduceOp.SUM) + c /= dist.get_world_size() + + for i, layer in enumerate(self.layers): + layer.embedding.weight.data.copy_(centers[i]) + + self.initted.fill_(True) + + def _single_commitment_loss( + self, + x: torch.Tensor, + quant: torch.Tensor, + ) -> torch.Tensor: + """Commitment loss for a single cumulative quantization tensor. + + - cos: (1 - cosine_similarity) * weight + - l2: (x - quant)^2.mean() * weight + - l1: |x - quant|.mean() * weight + + Both directions are always summed: + loss1 = encoder-toward-quant (gradient flows into encoder) + loss2 = quant-toward-encoder (gradient flows into codebook) + + Args: + x (Tensor): original input, shape (B, D). + quant (Tensor): cumulative quantized output at one layer, + shape (B, D). + + Returns: + Tensor: scalar commitment loss for this layer. + """ + if self.commitment_loss_type == "cos": + loss1 = ( + 1 - F.cosine_similarity(x, quant.detach(), dim=-1) + ).mean() * self.commitment_w1 + loss2 = ( + 1 - F.cosine_similarity(x.detach(), quant, dim=-1) + ).mean() * self.commitment_w2 + elif self.commitment_loss_type == "l1": + loss1 = (x - quant.detach()).abs().mean() * self.commitment_w1 + loss2 = (x.detach() - quant).abs().mean() * self.commitment_w2 + else: # 'l2' + loss1 = (x - quant.detach()).pow(2.0).mean() * self.commitment_w1 + loss2 = (x.detach() - quant).pow(2.0).mean() * self.commitment_w2 + return loss1 + loss2 + + @staticmethod + def _apply_rotation_trick( + x: torch.Tensor, + quant: torch.Tensor, + ) -> torch.Tensor: + """Apply rotation trick for improved STE gradient estimation. + + Implements equation 4.2 from https://arxiv.org/abs/2410.06424. + Replaces standard STE with a Householder reflection that rotates + the gradient direction from x toward quant. + + Args: + x (Tensor): original input with gradient, shape (B, D). + quant (Tensor): quantized output (will be detached), + shape (B, D). + + Returns: + Tensor: rotated output with gradient flowing through x. + """ + quant_detached = quant.detach() + x_detached = x.detach() + + quant_norms = torch.linalg.vector_norm(quant_detached, dim=-1).unsqueeze( + 1 + ) # (B, 1) + x_norms = torch.linalg.vector_norm(x_detached, dim=-1).unsqueeze(1) # (B, 1) + lambda_ = quant_norms / (x_norms + 1e-8) # (B, 1) + + x_hat = x_detached / (x_norms + 1e-8) # (B, D) + quant_hat = quant_detached / (quant_norms + 1e-8) # (B, D) + + normalized_sum = F.normalize(x_hat + quant_hat, p=2, dim=1) # (B, D) + + x_unsq = x.unsqueeze(1) # (B, 1, D) + + # Eq 4.2: Householder reflection + sum_projection = ( + x_unsq @ normalized_sum.unsqueeze(2) @ normalized_sum.unsqueeze(1) + ) # (B, 1, D) + rescaled_embeddings = ( + x_unsq @ x_hat.unsqueeze(2) @ quant_hat.unsqueeze(1) + ) # (B, 1, D) + return lambda_ * ( + x_unsq - 2 * sum_projection + 2 * rescaled_embeddings + ).squeeze(1) + + def output_dim(self) -> int: + """Output dimension of the module.""" + return self.embed_dim + + def forward( + self, + input: torch.Tensor, + temperature: float = 1.0, + ) -> ResidualQuantizedOutput: + """Forward the multi-layer residual quantization. + + Training flow: + 1. If kmeans_init and not initialized -> init_embed_(input) + 2. For each layer: quantize detached residual, accumulate + into aggregated_quants and compute per-layer commitment loss + in-place (avoids storing a quant_list of clones). + 3. Mean of per-layer commitment losses (cos/l2 with latent_weight) + 4. STE gradient pass-through (or rotation trick) + + Args: + input (Tensor): input embeddings, shape (B, D). + temperature (float): temperature for Gumbel-Softmax. + + Returns: + ResidualQuantizedOutput: (cluster_ids, quantized_embeddings, + quantization_loss). + """ + # Step 1: KMeans initialization (first training forward only) + if self.training: + self.init_embed_(input) + + # Detach residual for VQ assignment (gradient flows via STE only). + residual = input.detach() + all_ids: List[torch.Tensor] = [] + commitment_loss_list: List[torch.Tensor] = [] + aggregated_quants = torch.zeros_like(input) + + # Step 2: per-layer residual quantization + for layer in self.layers: + if self.normalize_residuals: + residual = F.normalize(residual, dim=-1) + + quantized = layer(residual, temperature=temperature) + all_ids.append(quantized.ids) + + # Separate raw lookup: ``quantized.embeddings`` already applies + # STE (gradient -> encoder), but the commitment loss + residual + # update need the un-STE'd codebook vector with gradient still + # flowing into ``layer.embedding.weight``. + raw_emb = layer.embedding(quantized.ids) + residual = residual - raw_emb.detach() + aggregated_quants = aggregated_quants + raw_emb + + commitment_loss_list.append( + self._single_commitment_loss(input, aggregated_quants) + ) + + cluster_ids = torch.stack(all_ids, dim=-1) # (B, n_layers) + + # Step 3: aggregate per-layer commitment loss + commitment_loss = torch.mean(torch.stack(commitment_loss_list)) + + # Step 4: STE or rotation trick (quants_trunc = final accumulated) + quants_trunc = aggregated_quants + if self.training: + if self.rotation_trick: + quants_trunc = self._apply_rotation_trick(input, quants_trunc) + else: + quants_trunc = input + (quants_trunc - input).detach() + + return ResidualQuantizedOutput( + cluster_ids=cluster_ids, + quantized_embeddings=quants_trunc, + quantization_loss=commitment_loss, + ) + + @torch.no_grad() + def get_codes(self, input: torch.Tensor) -> torch.Tensor: + """Assign semantic IDs without gradient computation. + + Args: + input (Tensor): input embeddings, shape (B, D). + + Returns: + Tensor: cluster ids, shape (B, n_layers). + """ + output = self.forward(input) + return output.cluster_ids + + @torch.no_grad() + def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: + """Get codebook embedding weights for a specific layer. + + Args: + layer_idx (int): index of the quantization layer. + + Returns: + Tensor: codebook weights, shape (n_embed, embed_dim). + """ + return self.layers[layer_idx].embedding.weight.data + + @torch.no_grad() + def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: + """Reconstruct embeddings from semantic ID codes. + + Args: + codes (Tensor): cluster ids, shape (B, n_layers). + + Returns: + Tensor: reconstructed embeddings, shape (B, D). + """ + quantized_sum = torch.zeros( + codes.shape[0], + self.embed_dim, + device=codes.device, + dtype=torch.float, + ) + for i, layer in enumerate(self.layers): + emb = layer.embedding(codes[:, i]) + quantized_sum = quantized_sum + emb + return quantized_sum diff --git a/tzrec/modules/sid_generation/rqvae.py b/tzrec/modules/sid_generation/rqvae.py new file mode 100644 index 000000000..2bbc969dd --- /dev/null +++ b/tzrec/modules/sid_generation/rqvae.py @@ -0,0 +1,372 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RQVAE: Encoder + ResidualQuantized + Decoder top-level wrapper.""" + +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from tzrec.modules.sid_generation.clip_loss import MaskedCLIPLoss +from tzrec.modules.sid_generation.residual_quantized import ResidualQuantized +from tzrec.utils.logging_util import logger + + +class RQVAE(nn.Module): + """RQ-VAE: Encoder + ResidualQuantized + Decoder. + + Supports optional CLIP contrastive learning. When use_clip=True, + forward accepts paired inputs (fea1, fea2) and computes CLIP loss + via a siamese network (shared parameters). + + Encoder/Decoder are configurable-depth MLPs built via hidden_dims: + Encoder: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1] -> embed_dim + Decoder: embed_dim -> hidden_dims[-1] -> ... -> hidden_dims[0] -> input_dim + ReLU activation between hidden layers. Decoder reverses hidden_dims + for symmetric structure. + + Args: + input_dim (int): original embedding dimension. Default: 512. + embed_dim (int): latent space dimension. Default: 64. + hidden_dims (List[int]): encoder hidden layer dimensions. + Decoder automatically reverses for symmetry. + Default: [input_dim // 2]. + n_layers (int): number of residual quantization layers. Default: 3. + n_embed (int|List[int]): codebook size per layer. Default: 256. + forward_mode (str): VQ forward mode ('ste'|'gumbel_softmax'). + Default: 'ste'. + normalize_residuals (bool): L2-normalize residuals. Default: False. + distance_type (str|List[str]): distance metric ('l2'|'cosine'). + Default: 'l2'. + commitment_loss (str|None): commitment loss type ('l2'|'cos'). + Default: follows loss_type (al_sid behavior). + latent_weight (List[float]): commitment loss weights [w1, w2]. + Default: [1.0, 0.5]. + rotation_trick (bool): STE rotation trick. Default: False. + kmeans_init (bool): KMeans codebook initialization. Default: True. + use_sinkhorn (bool): Sinkhorn uniform assignment. Default: True. + sinkhorn_iters (int): Sinkhorn iterations. Default: 5. + sinkhorn_epsilon (float): Sinkhorn sharpness. Default: 10.0. + loss_type (str): reconstruction loss ('mse'|'l1'|'cosine'). + Default: 'mse'. + use_clip (bool): enable CLIP contrastive learning. Default: False. + """ + + @staticmethod + def _build_mlp(dims: List[int]) -> nn.Sequential: + """Build MLP: dims[0] -> ... -> dims[-1], ReLU between hidden layers.""" + layers: List[nn.Module] = [] + for i in range(len(dims) - 1): + layers.append(nn.Linear(dims[i], dims[i + 1])) + if i < len(dims) - 2: # no activation after last layer + layers.append(nn.ReLU()) + return nn.Sequential(*layers) + + def __init__( + self, + input_dim: int = 512, + embed_dim: int = 64, + hidden_dims: Optional[List[int]] = None, + n_layers: int = 3, + n_embed: Union[int, List[int]] = 256, + forward_mode: str = "ste", + normalize_residuals: bool = False, + distance_type: Union[str, List[str]] = "l2", + commitment_loss: Optional[str] = None, + latent_weight: Sequence[float] = (1.0, 0.5), + rotation_trick: bool = False, + kmeans_init: bool = True, + use_sinkhorn: bool = True, + sinkhorn_iters: int = 5, + sinkhorn_epsilon: float = 10.0, + loss_type: str = "mse", + use_clip: bool = False, + ) -> None: + super().__init__() + + assert loss_type in ("mse", "l1", "cosine"), ( + f"loss_type must be 'mse', 'l1' or 'cosine', got '{loss_type}'" + ) + self.loss_type = loss_type + self.use_clip = use_clip + self.input_dim = input_dim + self.embed_dim = embed_dim + + self._is_inference = False + + if hidden_dims is None: + hidden_dims = [input_dim // 2] + + # commitment_loss defaults to follow loss_type (al_sid behavior: + # commitment_loss=loss_type, so mse -> l2 branch) + if commitment_loss is None: + commitment_loss = "l2" if loss_type == "mse" else loss_type + + enc_dims = [input_dim] + list(hidden_dims) + [embed_dim] + self.encoder = self._build_mlp(enc_dims) + + # Decoder is the symmetric reverse of the encoder. + dec_dims = [embed_dim] + list(reversed(hidden_dims)) + [input_dim] + self.decoder = self._build_mlp(dec_dims) + + self.quantizer = ResidualQuantized( + embed_dim=embed_dim, + n_layers=n_layers, + n_embed=n_embed, + forward_mode=forward_mode, + normalize_residuals=normalize_residuals, + distance_type=distance_type, + commitment_loss=commitment_loss, + latent_weight=latent_weight, + rotation_trick=rotation_trick, + kmeans_init=kmeans_init, + use_sinkhorn=use_sinkhorn, + sinkhorn_iters=sinkhorn_iters, + sinkhorn_epsilon=sinkhorn_epsilon, + ) + + # CLIP contrastive learning (optional) + if use_clip: + self.logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self.masked_clip_loss_fn = MaskedCLIPLoss() + + logger.info( + "RQVAE init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " + "n_layers=%d, n_embed=%s, forward_mode=%s, " + "normalize_residuals=%s, distance_type=%s, " + "commitment_loss=%s, latent_weight=%s, rotation_trick=%s, " + "kmeans_init=%s, use_sinkhorn=%s, " + "sinkhorn_iters=%d, sinkhorn_epsilon=%s, " + "loss_type=%s, use_clip=%s", + input_dim, + embed_dim, + hidden_dims, + n_layers, + n_embed, + forward_mode, + normalize_residuals, + distance_type, + commitment_loss, + list(latent_weight), + rotation_trick, + kmeans_init, + use_sinkhorn, + sinkhorn_iters, + sinkhorn_epsilon, + loss_type, + use_clip, + ) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode. (B, input_dim) -> (B, embed_dim).""" + return self.encoder(x) + + def decode(self, z_q: torch.Tensor) -> torch.Tensor: + """Decode. (B, embed_dim) -> (B, input_dim).""" + return self.decoder(z_q) + + def _cosine_loss(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + """Cosine distance loss: 1 - mean(cos_sim).""" + return (1 - F.cosine_similarity(x1, x2, dim=1)).mean() + + def compute_loss( + self, + x: torch.Tensor, + x_hat: torch.Tensor, + quant_loss: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Compute reconstruction loss + quantization loss + total loss. + + loss_total = recon_loss + quant_loss + Note: al_sid latent_loss_weight is declared but unused; + commitment_loss is added 1:1 with recon_loss. We align with this. + + Args: + x: original input, shape (B, input_dim). + x_hat: reconstructed output, shape (B, input_dim). + quant_loss: quantization (commitment) loss scalar. + + Returns: + dict with 'reconstruction_loss', 'quantization_loss', 'loss'. + """ + if self.loss_type == "mse": + recon_loss = F.mse_loss(x_hat, x, reduction="mean") + elif self.loss_type == "l1": + recon_loss = F.l1_loss(x_hat, x, reduction="mean") + elif self.loss_type == "cosine": + recon_loss = self._cosine_loss(x_hat, x) + else: + raise ValueError(f"Unsupported loss_type: {self.loss_type}") + + loss_total = recon_loss + quant_loss + + return { + "reconstruction_loss": recon_loss, + "quantization_loss": quant_loss, + "loss": loss_total, + } + + def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + """Dispatch based on use_clip. + + use_clip=False: forward(x) -> forward_rqvae(x) + use_clip=True: forward(fea1, fea2, clip_mask) -> forward_mixed(...) + """ + if self._is_inference or not self.use_clip: + assert len(args) >= 1, "Standard mode requires (x,)" + return self.forward_rqvae(args[0], **kwargs) + else: + assert len(args) == 3, "Mixed mode requires (fea1, fea2, clip_mask)" + return self.forward_mixed(args[0], args[1], args[2], **kwargs) + + def forward_rqvae( + self, x: torch.Tensor, temperature: float = 1.0 + ) -> Dict[str, torch.Tensor]: + """Standard RQ-VAE forward: encode -> quantize -> decode -> loss. + + Args: + x: (B, input_dim) original embedding. + temperature: Gumbel-Softmax temperature. + + Returns: + dict with keys: 'x_hat', 'codes', 'quantized', + 'reconstruction_loss', 'quantization_loss', 'loss'. + """ + z_e = self.encode(x) + quant_output = self.quantizer(z_e, temperature=temperature) + x_hat = self.decode(quant_output.quantized_embeddings) + + losses = self.compute_loss(x, x_hat, quant_output.quantization_loss) + + return { + "x_hat": x_hat, + "codes": quant_output.cluster_ids, + "quantized": quant_output.quantized_embeddings, + **losses, + } + + def _compute_masked_recon_loss( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + recon_mask: torch.Tensor, + ) -> torch.Tensor: + """Compute per-sample recon loss, masked to recon rows only. + + No boolean indexing, no data-dependent branching, + compatible with torch.compile. + + Args: + x_hat: (B, D) reconstructed output. + x: (B, D) original input. + recon_mask: (B,) bool, True = recon row. + """ + if self.loss_type == "mse": + per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) + elif self.loss_type == "l1": + per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) + elif self.loss_type == "cosine": + per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) + else: + raise ValueError(f"Unsupported loss_type: {self.loss_type}") + n_recon = recon_mask.float().sum().clamp(min=1) + return (per_sample * recon_mask.float()).sum() / n_recon + + def forward_mixed( + self, + fea1: torch.Tensor, + fea2: torch.Tensor, + clip_mask: torch.Tensor, + temperature: float = 1.0, + ) -> Dict[str, torch.Tensor]: + """Mixed recon + CLIP forward. + + All samples go through dual paths; mask separates recon and clip + loss contributions. + + Args: + fea1: (B, input_dim) main embedding (all rows valid). + fea2: (B, input_dim) clip embedding (recon rows == fea1). + clip_mask: (B,) bool, True = clip sample. + temperature: Gumbel-Softmax temperature. + """ + # Step 1: dual-path encode -> quantize -> decode + z_e1 = self.encode(fea1) + quant1 = self.quantizer(z_e1, temperature=temperature) + x_hat1 = self.decode(quant1.quantized_embeddings) + + z_e2 = self.encode(fea2) + quant2 = self.quantizer(z_e2, temperature=temperature) + x_hat2 = self.decode(quant2.quantized_embeddings) + + # Step 2: recon loss (only recon rows, no branching) + recon_mask = ~clip_mask + recon_loss = self._compute_masked_recon_loss(x_hat1, fea1, recon_mask) + + # Step 3: masked CLIP loss (only clip rows) + features = { + "image_embed": x_hat1, + "text_embed": x_hat2, + "image_embed_ori": fea1, + "text_embed_ori": fea2, + "logit_scale_self": self.logit_scale_self.exp(), + "logit_scale_cl": self.logit_scale_cl.exp(), + "logit_scale": self.logit_scale.exp(), + } + clip_result = self.masked_clip_loss_fn(features, clip_mask) + + # Step 4: commitment loss (average of two paths) + commitment = (quant1.quantization_loss + quant2.quantization_loss) / 2 + + return { + "codes": quant1.cluster_ids, + "quantized": quant1.quantized_embeddings, + "x_hat": x_hat1, + "recon_loss": recon_loss, + "clip_loss": clip_result["clip_loss"], + "clip_acc": clip_result["clip_acc"], + "loss_self": clip_result["loss_self"], + "loss_ori": clip_result["loss_ori"], + "loss_cl": clip_result["loss_cl"], + "commitment_loss": commitment, + "loss": recon_loss + clip_result["clip_loss"] + commitment, + } + + @torch.no_grad() + def get_codes(self, x: torch.Tensor) -> torch.Tensor: + """Inference: get semantic IDs. + + Args: + x: (B, input_dim) original embedding. + + Returns: + Tensor: codes, shape (B, n_layers). + """ + z_e = self.encode(x) + return self.quantizer.get_codes(z_e) + + @torch.no_grad() + def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: + """Reconstruct embedding from semantic IDs (through decoder). + + Args: + codes: (B, n_layers) semantic ID codes. + + Returns: + Tensor: x_hat, shape (B, input_dim). + """ + z_q = self.quantizer.decode_codes(codes) + return self.decode(z_q) diff --git a/tzrec/modules/sid_generation/types.py b/tzrec/modules/sid_generation/types.py new file mode 100644 index 000000000..e0596e3c0 --- /dev/null +++ b/tzrec/modules/sid_generation/types.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data types for SID generation: enums and output tuples.""" + +from enum import Enum +from typing import NamedTuple + +import torch + + +class QuantizeForwardMode(Enum): + """Forward mode for vector quantization. + + Attributes: + GUMBEL_SOFTMAX: use Gumbel-Softmax reparameterization. + STE: use Straight-Through Estimator. + """ + + GUMBEL_SOFTMAX = 1 + STE = 2 + + +class QuantizeOutput(NamedTuple): + """Output of a single vector quantization layer. + + Attributes: + embeddings (Tensor): quantized embeddings, shape (B, D). + ids (Tensor): codebook indices, shape (B,). + """ + + embeddings: torch.Tensor + ids: torch.Tensor + + +class ResidualQuantizedOutput(NamedTuple): + """Output of the residual quantization module. + + Attributes: + cluster_ids (Tensor): codebook indices per layer, shape (B, n_layers). + quantized_embeddings (Tensor): sum of quantized embeddings, shape (B, D). + quantization_loss (Tensor): total commitment loss scalar. + """ + + cluster_ids: torch.Tensor + quantized_embeddings: torch.Tensor + quantization_loss: torch.Tensor diff --git a/tzrec/modules/sid_generation/vector_quantize.py b/tzrec/modules/sid_generation/vector_quantize.py new file mode 100644 index 000000000..d4955f2db --- /dev/null +++ b/tzrec/modules/sid_generation/vector_quantize.py @@ -0,0 +1,264 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single codebook vector quantization with Sinkhorn uniform assignment.""" + +from typing import Tuple + +import torch +import torch.distributed as dist +from torch import nn +from torch.nn import functional as F + +from tzrec.modules.sid_generation.kmeans import _squared_euclidean_distance +from tzrec.modules.sid_generation.types import ( + QuantizeForwardMode, + QuantizeOutput, +) + + +def _gumbel_softmax_sample( + logits: torch.Tensor, + temperature: float = 1.0, + hard: bool = True, +) -> torch.Tensor: + """Sample from the Gumbel-Softmax distribution. + + Args: + logits (Tensor): un-normalized log probabilities, shape (B, N). + temperature (float): temperature for Gumbel-Softmax. + hard (bool): if True, return one-hot with straight-through gradient. + + Returns: + Tensor: soft or hard sample, shape (B, N). + """ + return F.gumbel_softmax(logits, tau=temperature, hard=hard, dim=-1) + + +@torch.no_grad() +def _sinkhorn( + cost: torch.Tensor, + n_iters: int = 5, + epsilon: float = 10.0, + is_distributed: bool = True, +) -> torch.Tensor: + """Sinkhorn-Knopp algorithm for optimal-transport based uniform assignment. + + Transforms a distance matrix into a soft assignment matrix via exponential + kernel and alternating row-column normalization, approximating a doubly + stochastic matrix to ensure uniform codebook utilization. + + Args: + cost (Tensor): distance matrix, shape (B, K) where K is codebook size. + IMPORTANT: must be z-score normalized and shifted to non-negative + before calling this function to avoid numerical overflow. + n_iters (int): number of Sinkhorn iterations. Default: 5. + epsilon (float): sharpness parameter for exp(-cost * epsilon). + Larger values produce sharper assignments. Default: 10.0. + is_distributed (bool): whether running in distributed mode. + If True, row sums are all_reduced across GPUs. Default: True. + + Returns: + Tensor: assignment matrix, shape (B, K). + Use Q.argmax(dim=-1) externally to get codebook indices. + """ + # Step 1: exponential kernel transform (B, K) -> (K, B) + Q = torch.exp(-cost * epsilon).t() + + # Global batch size for distributed training + if is_distributed and dist.is_initialized(): + B = Q.size(1) * dist.get_world_size() + else: + B = Q.size(1) + K = Q.size(0) + + # Step 2: global normalization — make matrix sum to 1 + sum_Q = torch.sum(Q) + if is_distributed and dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + 1e-8 + + # Step 3: alternating row-column normalization + for _ in range(n_iters): + # Row normalization: each prototype's total weight = 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if is_distributed and dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + 1e-8 + Q /= K + + # Column normalization: each sample's total weight = 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + 1e-8 + Q /= B + + # Step 4: scale back so columns sum to 1 (assignment) + Q *= B + return Q.t() # (B, K) + + +class VectorQuantize(nn.Module): + """Single codebook vector quantization layer. + + Maps continuous input vectors to the nearest codebook entry and returns + the quantized embeddings + codebook indices. The commitment loss is + computed at the residual-aggregator level by + :meth:`ResidualQuantized._single_commitment_loss` over the cumulative + quants (matching al_sid's ``RQBottleneck.compute_commitment_loss``); + this layer is intentionally loss-free. + + During training, Sinkhorn optimal-transport assignment is optionally + used to encourage uniform codebook utilization. + + Args: + embed_dim (int): dimension of each codebook embedding. + n_embed (int): number of codebook entries. + forward_mode (QuantizeForwardMode): quantization forward mode, + either GUMBEL_SOFTMAX or STE. Default: STE. + distance_type (str): distance metric, 'l2' or 'cosine'. + Default: 'l2'. + use_sinkhorn (bool): whether to use Sinkhorn uniform assignment + during training. Default: True. + sinkhorn_iters (int): number of Sinkhorn iterations. Default: 5. + sinkhorn_epsilon (float): Sinkhorn sharpness parameter for + exp(-cost * epsilon). Default: 10.0. + """ + + def __init__( + self, + embed_dim: int, + n_embed: int, + forward_mode: QuantizeForwardMode = QuantizeForwardMode.STE, + distance_type: str = "l2", + use_sinkhorn: bool = True, + sinkhorn_iters: int = 5, + sinkhorn_epsilon: float = 10.0, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.forward_mode = forward_mode + self.distance_type = distance_type + self.use_sinkhorn = use_sinkhorn + self.sinkhorn_iters = sinkhorn_iters + self.sinkhorn_epsilon = sinkhorn_epsilon + + self.embedding = nn.Embedding(n_embed, embed_dim) + nn.init.kaiming_uniform_(self.embedding.weight) + + @torch.no_grad() + def _compute_distances(self, x: torch.Tensor) -> torch.Tensor: + """Compute distances between input vectors and codebook entries. + + Supports L2 and cosine distance metrics. + + Args: + x (Tensor): input vectors, shape (B, D). + + Returns: + Tensor: pairwise distances, shape (B, n_embed). + """ + codebook = self.embedding.weight # (n_embed, D) + + if self.distance_type == "l2": + distances = _squared_euclidean_distance(x, codebook) + elif self.distance_type == "cosine": + x_norm = F.normalize(x, p=2, dim=1) + codebook_norm = F.normalize(codebook, p=2, dim=1) + distances = -torch.matmul(x_norm, codebook_norm.t()) + else: + raise ValueError( + f"Unsupported distance_type '{self.distance_type}', " + f"choose from ('l2', 'cosine')" + ) + return distances + + @torch.no_grad() + def _find_nearest_embedding( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Find nearest codebook entry for each input vector. + + During training with use_sinkhorn=True, applies z-score + normalization + non-negative shift before Sinkhorn assignment. + Otherwise falls back to argmin. + + Args: + x (Tensor): input vectors, shape (B, D). + + Returns: + ids (Tensor): codebook indices, shape (B,). + distances (Tensor): distance matrix, shape (B, n_embed). + """ + distances = self._compute_distances(x) # (B, n_embed) + + if self.training and self.use_sinkhorn: + # Sinkhorn requires non-negative cost; z-score then shift. + var, mean = torch.var_mean(distances, unbiased=False) + distances = (distances - mean) * var.add(1e-12).rsqrt() + distances = distances - distances.min() + + # Sinkhorn optimal-transport assignment + Q = _sinkhorn( + distances, + n_iters=self.sinkhorn_iters, + epsilon=self.sinkhorn_epsilon, + is_distributed=dist.is_initialized(), + ) + ids = Q.argmax(dim=-1) + else: + ids = distances.argmin(dim=-1) + + return ids, distances + + def forward( + self, + x: torch.Tensor, + temperature: float = 1.0, + ) -> QuantizeOutput: + """Forward the vector quantization layer. + + Training flow: + 1. compute distances (L2 or cosine) + 2. if use_sinkhorn: z-score normalize + Sinkhorn -> argmax + else: argmin + 3. compute differentiable embedding (STE or Gumbel-Softmax) + + Commitment loss is computed by the caller + (:meth:`ResidualQuantized._single_commitment_loss`). + + Args: + x (Tensor): input vectors, shape (B, D). + temperature (float): temperature for Gumbel-Softmax. + + Returns: + QuantizeOutput: named tuple of (embeddings, ids). + """ + # Step 1-2: find nearest codebook entry + ids, distances = self._find_nearest_embedding(x) + + # Step 3: differentiable embedding. Gumbel takes a separate path + # that combines all codebook entries; STE goes through a single + # embedding lookup. + if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: + weights = _gumbel_softmax_sample( + -distances, temperature=temperature, hard=True + ) + emb = weights @ self.embedding.weight + elif self.training and self.forward_mode == QuantizeForwardMode.STE: + quantized = self.embedding(ids) + # Straight-Through Estimator: gradient passes through + emb = x + (quantized - x).detach() + elif self.training: + raise ValueError(f"Unsupported forward mode: {self.forward_mode}") + else: + emb = self.embedding(ids) + + return QuantizeOutput(embeddings=emb, ids=ids) diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto index bef2062ea..d2c34ae0f 100644 --- a/tzrec/protos/model.proto +++ b/tzrec/protos/model.proto @@ -5,6 +5,7 @@ import "tzrec/protos/models/rank_model.proto"; import "tzrec/protos/models/multi_task_rank.proto"; import "tzrec/protos/models/match_model.proto"; import "tzrec/protos/models/general_rank_model.proto"; +import "tzrec/protos/models/sid_model.proto"; import "tzrec/protos/loss.proto"; import "tzrec/protos/metric.proto"; import "tzrec/protos/seq_encoder.proto"; @@ -76,6 +77,10 @@ message ModelConfig { TDM tdm = 400; RocketLaunching rocket_launching = 500; + + // SID generation models + SidRqvae sid_rqvae = 600; + SidRqkmeans sid_rqkmeans = 601; } optional uint32 num_class = 2 [default = 1]; diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto new file mode 100644 index 000000000..41513b51c --- /dev/null +++ b/tzrec/protos/models/sid_model.proto @@ -0,0 +1,94 @@ +syntax = "proto2"; +package tzrec.protos; + +import "google/protobuf/struct.proto"; + +message SinkhornConfig { + // Number of Sinkhorn iterations. + optional uint32 iters = 1 [default = 5]; + // Sinkhorn sharpness parameter (epsilon). + optional float epsilon = 2 [default = 10.0]; + // Whether Sinkhorn uniform assignment is enabled. Default: true. + // Set ``enabled: false`` to turn Sinkhorn off while still providing + // the sub-config (e.g. to override iters/epsilon for an A/B comparison). + optional bool enabled = 3 [default = true]; +} + +message ClipConfig { + // Name of the second feature (paired with embedding_feature_name + // to form a contrastive pair). + required string clip_feature_name = 1; + // Name of the per-sample boolean feature (0/1, value_dim=1) that + // flags whether the row is a CLIP pair (1) or a reconstruction-only + // row (0). Required for mixed recon+clip batches: the model uses + // this column directly as the ``clip_mask``. Replaces the prior + // bit-exact ``embedding == fea2`` discrimination, which silently + // mislabeled rows on any upstream float cast / normalization. + required string is_clip_pair_feature_name = 2; +} + +message SidRqvae { + // === Network structure === + // Input embedding dimension. + optional uint32 input_dim = 1 [default = 512]; + // Quantization latent dimension (encoder output / codebook dim). + optional uint32 embed_dim = 2 [default = 64]; + // Encoder hidden layer sizes, comma-separated, e.g. "256,128". + // Defaults to [input_dim // 2] when unset. + optional string hidden_dims = 3; + // Per-layer codebook size, comma-separated, e.g. "256,256,256". + // List length is the number of residual quantization layers; + // non-uniform codebooks such as "512,256,128" are supported. + optional string codebook = 5; + + // === Quantization strategy === + // VQ forward mode: "ste" or "gumbel_softmax". + optional string forward_mode = 6 [default = "ste"]; + // L2-normalize residuals before each quantization layer. + optional bool normalize_residuals = 7 [default = false]; + // Distance metric: "l2" or "cosine". + optional string distance_type = 9 [default = "l2"]; + // Commitment loss type: "l2", "l1" or "cos". + optional string commitment_loss = 10 [default = "l2"]; + // Commitment loss weights [w1, w2], comma-separated. + optional string latent_weight = 11 [default = "1.0,0.5"]; + // STE rotation trick. + optional bool rotation_trick = 12 [default = false]; + // KMeans codebook initialization on first training forward. + optional bool kmeans_init = 13 [default = true]; + + // === Optional sub-module configs === + // Sinkhorn uniform assignment. Default behavior when this block is + // omitted: enabled with iters=5, epsilon=10.0. Include the sub-config + // to override params; set ``enabled: false`` inside it to disable. + optional SinkhornConfig sinkhorn_config = 15; + // CLIP contrastive learning (disabled when unset). + optional ClipConfig clip_config = 16; + + // Reconstruction loss type: "mse", "l1", or "cosine". + optional string loss_type = 20 [default = "mse"]; + + // Name of the item embedding feature inside the input Batch. + optional string embedding_feature_name = 40 [default = "item_emb"]; +} + +message SidRqkmeans { + // Input embedding dimension (K-Means runs directly on raw embeddings, + // no encoder). + optional uint32 input_dim = 1 [default = 512]; + // Per-layer cluster counts, comma-separated, e.g. "256,256,256". + // List length is the number of residual quantization layers. All + // entries must be equal — the FAISS backend reuses a single + // ``faiss.Kmeans`` object across layers, so non-uniform codebooks + // are not supported (a uniformity assert fires at construction). + optional string codebook = 3; + // L2-normalize residuals before each layer. + optional bool normalize_residuals = 4 [default = true]; + // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a + // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, + // spherical: false, seed: 1234}. + optional google.protobuf.Struct faiss_kmeans_kwargs = 5; + + // Name of the item embedding feature inside the input Batch. + optional string embedding_feature_name = 40 [default = "item_emb"]; +} From 1f494e6a448f61cc4f5dc733e2614cdd41c2e2a8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 02:17:22 +0000 Subject: [PATCH 002/129] [refactor] SID proto: use repeated numeric fields, drop _sid_helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (sid_model.proto:54), replace the stringly-typed list fields with the tzrec-conventional repeated numeric types: hidden_dims / codebook : string -> repeated uint32 latent_weight : string -> repeated float This moves validation to proto-load time, restores text_format type checking, and removes the ad-hoc tzrec/models/_sid_helpers.py shim (parse_int_list / parse_float_list). It also fixes the always-truthy `if cfg.latent_weight:` guard noted in review — an unset repeated field is an empty (falsy) list, so the signature-default branch is now real. Wrappers/tests updated to pass lists; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/_sid_helpers.py | 24 ------------------------ tzrec/models/sid_rqkmeans.py | 5 ++--- tzrec/models/sid_rqkmeans_test.py | 4 ++-- tzrec/models/sid_rqvae.py | 14 +++++++------- tzrec/models/sid_rqvae_test.py | 8 ++++---- tzrec/protos/models/sid_model.proto | 19 ++++++++++--------- 6 files changed, 25 insertions(+), 49 deletions(-) delete mode 100644 tzrec/models/_sid_helpers.py diff --git a/tzrec/models/_sid_helpers.py b/tzrec/models/_sid_helpers.py deleted file mode 100644 index 04946003c..000000000 --- a/tzrec/models/_sid_helpers.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2024, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Shared helpers for SID-generation model wrappers.""" - -from typing import List - - -def parse_int_list(s: str) -> List[int]: - """Parse comma-separated int string, e.g. '256,128' -> [256, 128].""" - return [int(x.strip()) for x in s.split(",") if x.strip()] - - -def parse_float_list(s: str) -> List[float]: - """Parse comma-separated float string, e.g. '1.0,0.5' -> [1.0, 0.5].""" - return [float(x.strip()) for x in s.split(",") if x.strip()] diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 398e266cc..e69956a2b 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -28,7 +28,6 @@ from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature -from tzrec.models._sid_helpers import parse_int_list from tzrec.models.model import BaseModel from tzrec.modules.sid_generation import RQKMeans from tzrec.modules.sid_generation.kmeans import recon_diagnostics @@ -80,8 +79,8 @@ def __init__( cfg = self._model_config # SidRqkmeans proto message self._embedding_feature_name = cfg.embedding_feature_name - assert cfg.codebook, "codebook must be set, e.g. '256,256,256'" - n_embed_list = parse_int_list(cfg.codebook) + assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" + n_embed_list = list(cfg.codebook) n_layers = len(n_embed_list) self._faiss_kwargs = ( diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 25b4f0800..2beb9da1a 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -40,14 +40,14 @@ def _create_model(self, input_dim=32, n_layers=2, niter=5): """Create a SidRqkmeans configured for offline FAISS fit.""" from google.protobuf.struct_pb2 import Struct - n_embed_str = ",".join(["16"] * n_layers) + n_embed_list = [16] * n_layers faiss_kwargs = Struct() faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) sid_rqkmeans_cfg = sid_model_pb2.SidRqkmeans( input_dim=input_dim, - codebook=n_embed_str, + codebook=n_embed_list, normalize_residuals=False, faiss_kmeans_kwargs=faiss_kwargs, embedding_feature_name="item_emb", diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 0520c8f23..632f03ef4 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -23,7 +23,6 @@ from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature -from tzrec.models._sid_helpers import parse_float_list, parse_int_list from tzrec.models.model import BaseModel from tzrec.modules.sid_generation import RQVAE from tzrec.protos.model_pb2 import ModelConfig @@ -63,15 +62,16 @@ def __init__( cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None ) - hidden_dims = parse_int_list(cfg.hidden_dims) if cfg.hidden_dims else None - # Only forward latent_weight when proto sets it; otherwise let - # RQVAE / ResidualQuantized apply their signature default (1.0, 0.5). + hidden_dims = list(cfg.hidden_dims) if cfg.hidden_dims else None + # Only forward latent_weight when the user set it (repeated field is + # empty when unset); otherwise let RQVAE / ResidualVectorQuantizer + # apply their signature default (1.0, 0.5). rqvae_extra: Dict[str, Any] = {} if cfg.latent_weight: - rqvae_extra["latent_weight"] = parse_float_list(cfg.latent_weight) + rqvae_extra["latent_weight"] = list(cfg.latent_weight) - assert cfg.codebook, "codebook must be set, e.g. '256,256,256'" - n_embed_list = parse_int_list(cfg.codebook) + assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" + n_embed_list = list(cfg.codebook) n_layers = len(n_embed_list) use_sinkhorn = True diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index a87e58133..518c78d33 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -47,11 +47,11 @@ class SidRqvaeTest(unittest.TestCase): def _create_model(self, use_clip=False, input_dim=32, embed_dim=8, n_layers=2): """Helper to create a SidRqvae model with minimal config.""" - n_embed_str = ",".join(["16"] * n_layers) + n_embed_list = [16] * n_layers sid_rqvae_cfg = sid_model_pb2.SidRqvae( input_dim=input_dim, embed_dim=embed_dim, - codebook=n_embed_str, + codebook=n_embed_list, forward_mode="ste", loss_type="mse", kmeans_init=False, @@ -346,11 +346,11 @@ def test_sinkhorn_config_enabled_false(self) -> None: Previously ``use_sinkhorn`` was hard-coded ``True`` and the proto block was honored only for iters/epsilon. """ - n_embed_str = ",".join(["16"] * 2) + n_embed_list = [16] * 2 sid_rqvae_cfg = sid_model_pb2.SidRqvae( input_dim=32, embed_dim=8, - codebook=n_embed_str, + codebook=n_embed_list, forward_mode="ste", loss_type="mse", kmeans_init=False, diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 41513b51c..0385d5728 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -33,13 +33,13 @@ message SidRqvae { optional uint32 input_dim = 1 [default = 512]; // Quantization latent dimension (encoder output / codebook dim). optional uint32 embed_dim = 2 [default = 64]; - // Encoder hidden layer sizes, comma-separated, e.g. "256,128". + // Encoder hidden layer sizes, e.g. [256, 128]. // Defaults to [input_dim // 2] when unset. - optional string hidden_dims = 3; - // Per-layer codebook size, comma-separated, e.g. "256,256,256". + repeated uint32 hidden_dims = 3; + // Per-layer codebook size, e.g. [256, 256, 256]. // List length is the number of residual quantization layers; - // non-uniform codebooks such as "512,256,128" are supported. - optional string codebook = 5; + // non-uniform codebooks such as [512, 256, 128] are supported. + repeated uint32 codebook = 5; // === Quantization strategy === // VQ forward mode: "ste" or "gumbel_softmax". @@ -50,8 +50,9 @@ message SidRqvae { optional string distance_type = 9 [default = "l2"]; // Commitment loss type: "l2", "l1" or "cos". optional string commitment_loss = 10 [default = "l2"]; - // Commitment loss weights [w1, w2], comma-separated. - optional string latent_weight = 11 [default = "1.0,0.5"]; + // Commitment loss weights [w1, w2]. Defaults to [1.0, 0.5] when unset + // (applied by RQVAE / ResidualVectorQuantizer). + repeated float latent_weight = 11; // STE rotation trick. optional bool rotation_trick = 12 [default = false]; // KMeans codebook initialization on first training forward. @@ -76,12 +77,12 @@ message SidRqkmeans { // Input embedding dimension (K-Means runs directly on raw embeddings, // no encoder). optional uint32 input_dim = 1 [default = 512]; - // Per-layer cluster counts, comma-separated, e.g. "256,256,256". + // Per-layer cluster counts, e.g. [256, 256, 256]. // List length is the number of residual quantization layers. All // entries must be equal — the FAISS backend reuses a single // ``faiss.Kmeans`` object across layers, so non-uniform codebooks // are not supported (a uniformity assert fires at construction). - optional string codebook = 3; + repeated uint32 codebook = 3; // L2-normalize residuals before each layer. optional bool normalize_residuals = 4 [default = true]; // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a From ebdbd3481c9d7c04acfe219917aec79d73565af7 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 02:23:56 +0000 Subject: [PATCH 003/129] [refactor] SID modules: introduce ResidualQuantizer abstract base + renames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (residual_kmeans.py:29, residual_quantized.py:30), abstract the two residual-quantization backends behind a shared base and align names with the tzrec module convention: - new modules/sid_generation/residual_quantizer.py: ResidualQuantizer abstract base — owns embed_dim/n_layers/n_embed_list/normalize_residuals, the backend-agnostic decode_codes (via a _lookup_code primitive) and output_dim. Subclasses build self.layers and implement forward/get_codes/ get_codebook_embeddings/_lookup_code. - residual_quantized.py -> residual_vector_quantizer.py ResidualQuantized -> ResidualVectorQuantizer (VQ, gradient-trained) - residual_kmeans.py -> residual_kmeans_quantizer.py ResidualKMeans -> ResidualKMeansQuantizer (offline FAISS) - types.ResidualQuantizedOutput -> ResidualQuantizerOutput - rqvae.py / __init__.py / tests / docstrings updated to the new names. Behavior unchanged; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae_test.py | 14 +- tzrec/modules/sid_generation/__init__.py | 20 ++- tzrec/modules/sid_generation/kmeans.py | 14 +- ...kmeans.py => residual_kmeans_quantizer.py} | 60 ++----- .../sid_generation/residual_quantizer.py | 150 ++++++++++++++++++ ...ntized.py => residual_vector_quantizer.py} | 59 ++----- tzrec/modules/sid_generation/rqvae.py | 10 +- tzrec/modules/sid_generation/types.py | 2 +- .../modules/sid_generation/vector_quantize.py | 4 +- 9 files changed, 212 insertions(+), 121 deletions(-) rename tzrec/modules/sid_generation/{residual_kmeans.py => residual_kmeans_quantizer.py} (87%) create mode 100644 tzrec/modules/sid_generation/residual_quantizer.py rename tzrec/modules/sid_generation/{residual_quantized.py => residual_vector_quantizer.py} (88%) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 518c78d33..c44042c5f 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -315,12 +315,12 @@ def test_commitment_loss_l1_branch(self) -> None: Previously ``"l1"`` silently fell through to the L2 branch. """ - from tzrec.modules.sid_generation.residual_quantized import ( - ResidualQuantized, + from tzrec.modules.sid_generation.residual_vector_quantizer import ( + ResidualVectorQuantizer, ) torch.manual_seed(0) - rq = ResidualQuantized( + rq = ResidualVectorQuantizer( embed_dim=8, n_layers=2, n_embed=4, @@ -386,13 +386,13 @@ def test_sinkhorn_config_default_enabled(self) -> None: self.assertTrue(layer.use_sinkhorn) def test_commitment_loss_invalid_raises(self) -> None: - """ResidualQuantized rejects unknown commitment_loss spellings.""" - from tzrec.modules.sid_generation.residual_quantized import ( - ResidualQuantized, + """ResidualVectorQuantizer rejects unknown commitment_loss spellings.""" + from tzrec.modules.sid_generation.residual_vector_quantizer import ( + ResidualVectorQuantizer, ) with self.assertRaisesRegex(AssertionError, "commitment_loss"): - ResidualQuantized( + ResidualVectorQuantizer( embed_dim=8, n_layers=2, n_embed=4, diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py index 4d4a5d5f2..d9a414556 100644 --- a/tzrec/modules/sid_generation/__init__.py +++ b/tzrec/modules/sid_generation/__init__.py @@ -15,12 +15,15 @@ from tzrec.modules.sid_generation.kmeans import ( KMeansLayer, ) -from tzrec.modules.sid_generation.residual_kmeans import ( - ResidualKMeans, +from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, RQKMeans, ) -from tzrec.modules.sid_generation.residual_quantized import ( - ResidualQuantized, +from tzrec.modules.sid_generation.residual_quantizer import ( + ResidualQuantizer, +) +from tzrec.modules.sid_generation.residual_vector_quantizer import ( + ResidualVectorQuantizer, ) from tzrec.modules.sid_generation.rqvae import ( RQVAE, @@ -28,7 +31,7 @@ from tzrec.modules.sid_generation.types import ( QuantizeForwardMode, QuantizeOutput, - ResidualQuantizedOutput, + ResidualQuantizerOutput, ) from tzrec.modules.sid_generation.vector_quantize import ( VectorQuantize, @@ -37,12 +40,13 @@ __all__ = [ "QuantizeForwardMode", "QuantizeOutput", - "ResidualQuantizedOutput", + "ResidualQuantizerOutput", "VectorQuantize", "GatherLayer", - "ResidualQuantized", + "ResidualQuantizer", + "ResidualVectorQuantizer", "RQVAE", "KMeansLayer", - "ResidualKMeans", + "ResidualKMeansQuantizer", "RQKMeans", ] diff --git a/tzrec/modules/sid_generation/kmeans.py b/tzrec/modules/sid_generation/kmeans.py index 3751d45f7..1ebb1a64d 100644 --- a/tzrec/modules/sid_generation/kmeans.py +++ b/tzrec/modules/sid_generation/kmeans.py @@ -15,11 +15,11 @@ SID models: * :class:`KMeansLayer` — per-layer centroid container used by - :class:`ResidualKMeans` / :class:`RQKMeans`. Centroids are injected + :class:`ResidualKMeansQuantizer` / :class:`RQKMeans`. Centroids are injected by the FAISS backend via ``load_centroids_``; the only forward path is ``predict``. * :func:`_kmeans` / :func:`_residual_kmeans` — pure-torch Lloyd's - K-Means + residual variant, used by :class:`ResidualQuantized` to + K-Means + residual variant, used by :class:`ResidualVectorQuantizer` to warm-start the RQ-VAE codebook on the first training batch. They run once on a single batch of encoder outputs (typically ~2k × 64), so pulling in FAISS here would be all overhead and no benefit. @@ -39,7 +39,7 @@ def recon_diagnostics( """MSE + relative-L1 reconstruction diagnostics. Shared by :meth:`SidRqkmeans.update_metric` (which wants tensors for - ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeans.train_offline`'s + ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeansQuantizer.train_offline`'s per-layer log line (which converts to Python floats via ``.item()``). Args: @@ -100,7 +100,7 @@ def _kmeans_plus_plus( Selects initial centroids via distance-weighted probability sampling to ensure well-spread starting points. Used by the RQ-VAE codebook - init path (``ResidualQuantized.kmeans_init``); RQKMeans itself no + init path (``ResidualVectorQuantizer.kmeans_init``); RQKMeans itself no longer needs it. Args: @@ -138,7 +138,7 @@ def _kmeans( ) -> Tuple[torch.Tensor, torch.Tensor]: """Lloyd's K-Means with KMeans++ initialization. - Used by :class:`ResidualQuantized.init_embed_` to warm-start the + Used by :class:`ResidualVectorQuantizer.init_embed_` to warm-start the RQ-VAE codebook on the first training batch. Args: @@ -179,7 +179,7 @@ def _residual_kmeans( ) -> List[torch.Tensor]: """Residual K-Means: per-layer cluster then subtract centroids. - Used by :class:`ResidualQuantized.init_embed_` to seed every RQ + Used by :class:`ResidualVectorQuantizer.init_embed_` to seed every RQ codebook layer in one pass over the first training batch. Args: @@ -202,7 +202,7 @@ class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. Centroids are populated externally by ``load_centroids_`` (called per - layer by the FAISS backend in :class:`ResidualKMeans`); ``predict`` + layer by the FAISS backend in :class:`ResidualKMeansQuantizer`); ``predict`` is the only forward path. PyTorch state-dict keys are scoped by attribute path (``layers..centroids``), so renaming the class does not break existing checkpoints. diff --git a/tzrec/modules/sid_generation/residual_kmeans.py b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py similarity index 87% rename from tzrec/modules/sid_generation/residual_kmeans.py rename to tzrec/modules/sid_generation/residual_kmeans_quantizer.py index 60331030d..cdd55c6a0 100644 --- a/tzrec/modules/sid_generation/residual_kmeans.py +++ b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Multi-layer residual K-Means: ResidualKMeans and RQKMeans wrapper. +"""Multi-layer residual K-Means: ResidualKMeansQuantizer and RQKMeans wrapper. Training is FAISS-only: the codebook is built once via ``train_offline`` over the full embedding matrix; ``forward`` is read-only (predict + lookup). @@ -23,10 +23,11 @@ from torch.nn import functional as F from tzrec.modules.sid_generation.kmeans import KMeansLayer, recon_diagnostics +from tzrec.modules.sid_generation.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger -class ResidualKMeans(nn.Module): +class ResidualKMeansQuantizer(ResidualQuantizer): """Multi-layer residual K-Means with offline FAISS training. Each layer quantizes the residual from the previous layer: @@ -62,33 +63,21 @@ def __init__( normalize_residuals: bool = False, faiss_kmeans_kwargs: Optional[Dict] = None, ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.n_layers = n_layers - self.normalize_residuals = normalize_residuals + super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) self.faiss_kmeans_kwargs = dict(faiss_kmeans_kwargs or {}) - if isinstance(n_embed, int): - n_embed_list = [n_embed] * n_layers - else: - assert len(n_embed) == n_layers, ( - "length of n_embed and n_layers must be same, " - f"but got {len(n_embed)} vs {n_layers}" - ) - n_embed_list = list(n_embed) # ``train_offline`` reuses a single ``faiss.Kmeans`` instance across # layers, so non-uniform codebooks would silently train layers 1+ # with ``K=n_embed_list[0]``. Fail fast instead. - assert len(set(n_embed_list)) == 1, ( - "ResidualKMeans / RQKMeans require a uniform codebook size " - f"across layers; got {n_embed_list}." + assert len(set(self.n_embed_list)) == 1, ( + "ResidualKMeansQuantizer / RQKMeans require a uniform codebook " + f"size across layers; got {self.n_embed_list}." ) - self.n_embed_list = n_embed_list self.layers = nn.ModuleList( [ KMeansLayer( - n_clusters=n_embed_list[i], + n_clusters=self.n_embed_list[i], n_features=embed_dim, ) for i in range(n_layers) @@ -100,10 +89,6 @@ def all_initialized(self) -> bool: """Whether all layers have been initialized via offline FAISS.""" return all(layer.is_initialized for layer in self.layers) - def output_dim(self) -> int: - """Output dimension of the module.""" - return self.embed_dim - def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Assign codes per layer and sum the centroids. @@ -169,26 +154,9 @@ def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """ return self.layers[layer_idx].centroids - @torch.no_grad() - def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: - """Reconstruct embeddings from semantic ID codes. - - Args: - codes (Tensor): cluster ids, shape (B, n_layers). - - Returns: - Tensor: reconstructed embeddings, shape (B, D). - """ - quantized_sum = torch.zeros( - codes.shape[0], - self.embed_dim, - device=codes.device, - dtype=torch.float, - ) - for i, layer in enumerate(self.layers): - emb = layer.centroids[codes[:, i]] - quantized_sum = quantized_sum + emb - return quantized_sum + def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: + """Look up codebook vectors via the layer's centroid table.""" + return self.layers[layer_idx].centroids[code_idx] @torch.no_grad() def train_offline( @@ -267,7 +235,7 @@ def train_offline( out_t = torch.from_numpy(out) ref_t = torch.from_numpy(out + x) # x_in = out + residual logger.info( - "[ResidualKMeans][offline_faiss][layer %d] %s", + "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, self._calc_loss(ref_t, out_t), ) @@ -277,7 +245,7 @@ def train_offline( self.layers[layer_idx].load_centroids_(centroids_t) if verbose: logger.info( - "[ResidualKMeans][offline_faiss] layer %d finished", + "[ResidualKMeansQuantizer][offline_faiss] layer %d finished", layer_idx, ) @@ -317,7 +285,7 @@ def __init__( ) -> None: super().__init__() self.embed_dim = embed_dim - self.quantizer = ResidualKMeans( + self.quantizer = ResidualKMeansQuantizer( embed_dim=embed_dim, n_layers=n_layers, n_embed=n_embed, diff --git a/tzrec/modules/sid_generation/residual_quantizer.py b/tzrec/modules/sid_generation/residual_quantizer.py new file mode 100644 index 000000000..ff238bdc7 --- /dev/null +++ b/tzrec/modules/sid_generation/residual_quantizer.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ResidualQuantizer: abstract base for multi-layer residual quantizers.""" + +from typing import List, Union + +import torch +from torch import nn + + +def normalize_n_embed(n_embed: Union[int, List[int]], n_layers: int) -> List[int]: + """Broadcast a scalar codebook size to a per-layer list (or validate one). + + Args: + n_embed (int|List[int]): codebook size, shared or per-layer. + n_layers (int): number of residual quantization layers. + + Returns: + List[int]: per-layer codebook sizes, length ``n_layers``. + """ + if isinstance(n_embed, int): + return [n_embed] * n_layers + assert len(n_embed) == n_layers, ( + "length of n_embed and n_layers must be same, " + f"but got {len(n_embed)} vs {n_layers}" + ) + return list(n_embed) + + +class ResidualQuantizer(nn.Module): + """Abstract base for multi-layer residual quantization. + + Shared contract for the two SID quantizer backends — the VQ-based, + gradient-trained :class:`ResidualVectorQuantizer` and the K-Means-based, + offline-FAISS-trained :class:`ResidualKMeansQuantizer`. Both quantize the + residual of the previous layer: + + residual_0 = input + for each layer i: + (optionally) residual_i = L2_normalize(residual_i) + code_i, quantized_i = layer_i(residual_i) + residual_{i+1} = residual_i - quantized_i + output = sum of all quantized_i + + Semantic ID = (code_0, code_1, ..., code_{n_layers-1}). + + This base owns the structural invariants (``embed_dim``, ``n_layers``, + per-layer codebook sizes, residual normalization toggle) and the + backend-agnostic :meth:`decode_codes` / :meth:`output_dim`. Subclasses + build ``self.layers`` and implement :meth:`forward`, :meth:`get_codes`, + :meth:`get_codebook_embeddings`, and :meth:`_lookup_code`. + + Args: + embed_dim (int): feature / codebook dimension. + n_layers (int): number of residual quantization layers. + n_embed (int|List[int]): codebook size per layer. Default: 256. + normalize_residuals (bool): L2-normalize residuals before each + layer. Default: False. + """ + + def __init__( + self, + embed_dim: int, + n_layers: int, + n_embed: Union[int, List[int]] = 256, + normalize_residuals: bool = False, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.n_layers = n_layers + self.normalize_residuals = normalize_residuals + self.n_embed_list = normalize_n_embed(n_embed, n_layers) + # Subclasses MUST populate this with one quantization layer each. + self.layers: nn.ModuleList = nn.ModuleList() + + def output_dim(self) -> int: + """Output dimension of the module.""" + return self.embed_dim + + def forward(self, input: torch.Tensor): # noqa: ANN201 + """Assign codes per layer and accumulate the quantized output.""" + raise NotImplementedError + + @torch.no_grad() + def get_codes(self, input: torch.Tensor) -> torch.Tensor: + """Assign semantic IDs without updating the codebook. + + Args: + input (Tensor): input embeddings, shape (B, D). + + Returns: + Tensor: cluster ids, shape (B, n_layers). + """ + raise NotImplementedError + + @torch.no_grad() + def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: + """Get the codebook (centroid) weights for a specific layer. + + Args: + layer_idx (int): index of the quantization layer. + + Returns: + Tensor: codebook weights, shape (n_embed, embed_dim). + """ + raise NotImplementedError + + def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: + """Look up the codebook vectors for ``code_idx`` at ``layer_idx``. + + The single backend-specific primitive :meth:`decode_codes` builds on + (VQ reads ``embedding(idx)``, K-Means reads ``centroids[idx]``). + + Args: + layer_idx (int): index of the quantization layer. + code_idx (Tensor): codebook indices, shape (B,). + + Returns: + Tensor: looked-up codebook vectors, shape (B, embed_dim). + """ + raise NotImplementedError + + @torch.no_grad() + def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: + """Reconstruct embeddings from semantic ID codes (centroid sum). + + Args: + codes (Tensor): cluster ids, shape (B, n_layers). + + Returns: + Tensor: reconstructed embeddings, shape (B, embed_dim). + """ + quantized_sum = torch.zeros( + codes.shape[0], + self.embed_dim, + device=codes.device, + dtype=torch.float, + ) + for i in range(self.n_layers): + quantized_sum = quantized_sum + self._lookup_code(i, codes[:, i]) + return quantized_sum diff --git a/tzrec/modules/sid_generation/residual_quantized.py b/tzrec/modules/sid_generation/residual_vector_quantizer.py similarity index 88% rename from tzrec/modules/sid_generation/residual_quantized.py rename to tzrec/modules/sid_generation/residual_vector_quantizer.py index 92a22250b..c5ca829f0 100644 --- a/tzrec/modules/sid_generation/residual_quantized.py +++ b/tzrec/modules/sid_generation/residual_vector_quantizer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ResidualQuantized: multi-layer residual vector quantization with VQ layers.""" +"""ResidualVectorQuantizer: multi-layer residual VQ with gradient training.""" from typing import List, Sequence, Union @@ -19,15 +19,16 @@ from torch.nn import functional as F from tzrec.modules.sid_generation.kmeans import _residual_kmeans +from tzrec.modules.sid_generation.residual_quantizer import ResidualQuantizer from tzrec.modules.sid_generation.types import ( QuantizeForwardMode, - ResidualQuantizedOutput, + ResidualQuantizerOutput, ) from tzrec.modules.sid_generation.vector_quantize import VectorQuantize from tzrec.utils.logging_util import logger -class ResidualQuantized(nn.Module): +class ResidualVectorQuantizer(ResidualQuantizer): """Multi-layer residual vector quantization. Each layer quantizes the residual from the previous layer: @@ -86,13 +87,10 @@ def __init__( sinkhorn_iters: int = 5, sinkhorn_epsilon: float = 10.0, ) -> None: - super().__init__() + super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) assert commitment_loss in ("l2", "l1", "cos"), ( f"commitment_loss must be 'l2', 'l1' or 'cos', got {commitment_loss!r}" ) - self.embed_dim = embed_dim - self.n_layers = n_layers - self.normalize_residuals = normalize_residuals self.commitment_loss_type = commitment_loss self.rotation_trick = rotation_trick @@ -109,16 +107,6 @@ def __init__( ) mode_enum = self._FORWARD_MODE_MAP[forward_mode] - if isinstance(n_embed, int): - n_embed_list = [n_embed] * n_layers - else: - assert len(n_embed) == n_layers, ( - "length of n_embed and n_layers must be same, " - f"but got {len(n_embed)} vs {n_layers}" - ) - n_embed_list = list(n_embed) - self.n_embed_list = n_embed_list - if isinstance(distance_type, str): distance_types = [distance_type] * n_layers else: @@ -132,7 +120,7 @@ def __init__( [ VectorQuantize( embed_dim=embed_dim, - n_embed=n_embed_list[i], + n_embed=self.n_embed_list[i], forward_mode=mode_enum, distance_type=distance_types[i], use_sinkhorn=use_sinkhorn, @@ -144,7 +132,7 @@ def __init__( ) logger.info( - "ResidualQuantized init: embed_dim=%d, n_layers=%d, " + "ResidualVectorQuantizer init: embed_dim=%d, n_layers=%d, " "n_embed=%s, forward_mode=%s, normalize_residuals=%s, " "distance_type=%s, commitment_loss=%s, latent_weight=%s, " "rotation_trick=%s, kmeans_init=%s, use_sinkhorn=%s, " @@ -276,15 +264,11 @@ def _apply_rotation_trick( x_unsq - 2 * sum_projection + 2 * rescaled_embeddings ).squeeze(1) - def output_dim(self) -> int: - """Output dimension of the module.""" - return self.embed_dim - def forward( self, input: torch.Tensor, temperature: float = 1.0, - ) -> ResidualQuantizedOutput: + ) -> ResidualQuantizerOutput: """Forward the multi-layer residual quantization. Training flow: @@ -300,7 +284,7 @@ def forward( temperature (float): temperature for Gumbel-Softmax. Returns: - ResidualQuantizedOutput: (cluster_ids, quantized_embeddings, + ResidualQuantizerOutput: (cluster_ids, quantized_embeddings, quantization_loss). """ # Step 1: KMeans initialization (first training forward only) @@ -346,7 +330,7 @@ def forward( else: quants_trunc = input + (quants_trunc - input).detach() - return ResidualQuantizedOutput( + return ResidualQuantizerOutput( cluster_ids=cluster_ids, quantized_embeddings=quants_trunc, quantization_loss=commitment_loss, @@ -377,23 +361,6 @@ def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """ return self.layers[layer_idx].embedding.weight.data - @torch.no_grad() - def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: - """Reconstruct embeddings from semantic ID codes. - - Args: - codes (Tensor): cluster ids, shape (B, n_layers). - - Returns: - Tensor: reconstructed embeddings, shape (B, D). - """ - quantized_sum = torch.zeros( - codes.shape[0], - self.embed_dim, - device=codes.device, - dtype=torch.float, - ) - for i, layer in enumerate(self.layers): - emb = layer.embedding(codes[:, i]) - quantized_sum = quantized_sum + emb - return quantized_sum + def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: + """Look up codebook vectors via the layer's embedding table.""" + return self.layers[layer_idx].embedding(code_idx) diff --git a/tzrec/modules/sid_generation/rqvae.py b/tzrec/modules/sid_generation/rqvae.py index 2bbc969dd..95e9d48e8 100644 --- a/tzrec/modules/sid_generation/rqvae.py +++ b/tzrec/modules/sid_generation/rqvae.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""RQVAE: Encoder + ResidualQuantized + Decoder top-level wrapper.""" +"""RQVAE: Encoder + ResidualVectorQuantizer + Decoder top-level wrapper.""" from typing import Dict, List, Optional, Sequence, Union @@ -19,12 +19,14 @@ from torch.nn import functional as F from tzrec.modules.sid_generation.clip_loss import MaskedCLIPLoss -from tzrec.modules.sid_generation.residual_quantized import ResidualQuantized +from tzrec.modules.sid_generation.residual_vector_quantizer import ( + ResidualVectorQuantizer, +) from tzrec.utils.logging_util import logger class RQVAE(nn.Module): - """RQ-VAE: Encoder + ResidualQuantized + Decoder. + """RQ-VAE: Encoder + ResidualVectorQuantizer + Decoder. Supports optional CLIP contrastive learning. When use_clip=True, forward accepts paired inputs (fea1, fea2) and computes CLIP loss @@ -120,7 +122,7 @@ def __init__( dec_dims = [embed_dim] + list(reversed(hidden_dims)) + [input_dim] self.decoder = self._build_mlp(dec_dims) - self.quantizer = ResidualQuantized( + self.quantizer = ResidualVectorQuantizer( embed_dim=embed_dim, n_layers=n_layers, n_embed=n_embed, diff --git a/tzrec/modules/sid_generation/types.py b/tzrec/modules/sid_generation/types.py index e0596e3c0..fce27486f 100644 --- a/tzrec/modules/sid_generation/types.py +++ b/tzrec/modules/sid_generation/types.py @@ -41,7 +41,7 @@ class QuantizeOutput(NamedTuple): ids: torch.Tensor -class ResidualQuantizedOutput(NamedTuple): +class ResidualQuantizerOutput(NamedTuple): """Output of the residual quantization module. Attributes: diff --git a/tzrec/modules/sid_generation/vector_quantize.py b/tzrec/modules/sid_generation/vector_quantize.py index d4955f2db..9e043ffaf 100644 --- a/tzrec/modules/sid_generation/vector_quantize.py +++ b/tzrec/modules/sid_generation/vector_quantize.py @@ -110,7 +110,7 @@ class VectorQuantize(nn.Module): Maps continuous input vectors to the nearest codebook entry and returns the quantized embeddings + codebook indices. The commitment loss is computed at the residual-aggregator level by - :meth:`ResidualQuantized._single_commitment_loss` over the cumulative + :meth:`ResidualVectorQuantizer._single_commitment_loss` over the cumulative quants (matching al_sid's ``RQBottleneck.compute_commitment_loss``); this layer is intentionally loss-free. @@ -232,7 +232,7 @@ def forward( 3. compute differentiable embedding (STE or Gumbel-Softmax) Commitment loss is computed by the caller - (:meth:`ResidualQuantized._single_commitment_loss`). + (:meth:`ResidualVectorQuantizer._single_commitment_loss`). Args: x (Tensor): input vectors, shape (B, D). From 5c6f67a0d1e93ee0ef17ae95e5a10e40c16b9d3c Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 02:28:45 +0000 Subject: [PATCH 004/129] [refactor] SID models: add BaseSidModel parent for shared init/metric MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (sid_rqvae.py:211 — "add a parent class for rqkmeans and rqvae for shared init_metric and update_metric"), introduce tzrec/models/sid_model.py::BaseSidModel(BaseModel) and have both SidRqvae and SidRqkmeans inherit it. The base owns the structure the two models duplicated: - __init__ scaffolding: embedding_feature_name + codebook -> n_embed_list / n_layers. - _extract_feature(batch, feature_name=None) (replaces the per-model _extract_feature / _extract_embedding copies). - init_loss (SID losses are internal; no module to register). - init_metric registering the shared eval metrics (mse, unique_sid_ratio); subclasses call super().init_metric() then add extras (RQ-VAE: train-path mse; RQ-KMeans: rel_loss). - _update_unique_sid_ratio(codes) shared by both update_metric paths. - a default no-op update_train_metric (RQ-VAE overrides it). Behavior unchanged; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 118 +++++++++++++++++++++++++++++++++++ tzrec/models/sid_rqkmeans.py | 59 +++++------------- tzrec/models/sid_rqvae.py | 49 +++------------ 3 files changed, 140 insertions(+), 86 deletions(-) create mode 100644 tzrec/models/sid_model.py diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py new file mode 100644 index 000000000..cb0935186 --- /dev/null +++ b/tzrec/models/sid_model.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BaseSidModel: shared base for semantic-ID generation models.""" + +from typing import Any, List, Optional + +import torch +import torchmetrics + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.model import BaseModel +from tzrec.protos.model_pb2 import ModelConfig + + +class BaseSidModel(BaseModel): + """Shared base for semantic-ID (SID) generation models. + + Factors the structure common to :class:`SidRqvae` (RQ-VAE) and + :class:`SidRqkmeans` (residual K-Means): + + - reading the item-embedding feature out of ``Batch.dense_features``, + - parsing the per-layer ``codebook`` into ``n_embed_list`` / ``n_layers``, + - the eval metrics every SID model reports — reconstruction ``mse`` and + ``unique_sid_ratio`` (codebook coverage). + + Subclasses build their quantizer in ``__init__`` (after calling + ``super().__init__``) and implement :meth:`predict` and :meth:`loss`. + They extend :meth:`init_metric` / :meth:`update_metric` with any + backend-specific metrics. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. + """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + cfg = self._model_config + self._embedding_feature_name = cfg.embedding_feature_name + + assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" + self._n_embed_list = list(cfg.codebook) + self._n_layers = len(self._n_embed_list) + + def _extract_feature( + self, batch: Batch, feature_name: Optional[str] = None + ) -> torch.Tensor: + """Extract a named dense feature from ``Batch.dense_features``. + + Args: + batch (Batch): input batch data. + feature_name (str, optional): feature name to extract. + Defaults to ``self._embedding_feature_name``. + """ + if feature_name is None: + feature_name = self._embedding_feature_name + kt = batch.dense_features[BASE_DATA_GROUP] + return kt[feature_name] + + def init_loss(self) -> None: + """Initialize loss modules. + + SID models compute their losses internally and pass them through + ``predictions``; there is no external loss module to register. + """ + pass + + def init_metric(self) -> None: + """Initialize the eval metrics shared by all SID models. + + ``mse``: reconstruction error (input vs. quantized / decoded). + ``unique_sid_ratio``: codebook coverage = unique SIDs / batch size. + Subclasses call ``super().init_metric()`` then add their extras. + """ + self._metric_modules["mse"] = torchmetrics.MeanMetric() + self._metric_modules["unique_sid_ratio"] = torchmetrics.MeanMetric() + + def update_train_metric( + self, + predictions: dict, + batch: Batch, + ) -> None: + """Update train-path metric state. + + Default is a no-op: K-Means has no train-time codes, so only models + with a meaningful train signal (RQ-VAE) override this. + """ + return + + def _update_unique_sid_ratio(self, codes: torch.Tensor) -> None: + """Update the codebook-coverage metric (unique SIDs / batch size). + + Args: + codes (Tensor): semantic-ID codes, shape (B, n_layers). + """ + B = codes.shape[0] + unique_sids = torch.unique(codes, dim=0).shape[0] + self._metric_modules["unique_sid_ratio"].update(unique_sids / B) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index e69956a2b..22916854a 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -26,9 +26,9 @@ from google.protobuf.json_format import MessageToDict from torch import nn -from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature -from tzrec.models.model import BaseModel +from tzrec.models.sid_model import BaseSidModel from tzrec.modules.sid_generation import RQKMeans from tzrec.modules.sid_generation.kmeans import recon_diagnostics from tzrec.protos.model_pb2 import ModelConfig @@ -52,7 +52,7 @@ def _coerce_proto_numbers(d: Dict) -> Dict: return out -class SidRqkmeans(BaseModel): +class SidRqkmeans(BaseSidModel): """SID generation model using residual K-Means (FAISS-only). No gradient-based training. The codebook is built once at the end @@ -77,11 +77,6 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) cfg = self._model_config # SidRqkmeans proto message - self._embedding_feature_name = cfg.embedding_feature_name - - assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" - n_embed_list = list(cfg.codebook) - n_layers = len(n_embed_list) self._faiss_kwargs = ( _coerce_proto_numbers(MessageToDict(cfg.faiss_kmeans_kwargs)) @@ -91,8 +86,8 @@ def __init__( self._rqkmeans = RQKMeans( embed_dim=cfg.input_dim, - n_layers=n_layers, - n_embed=n_embed_list, + n_layers=self._n_layers, + n_embed=self._n_embed_list, normalize_residuals=cfg.normalize_residuals, faiss_kmeans_kwargs=self._faiss_kwargs, ) @@ -105,11 +100,6 @@ def __init__( # Add dummy param to keep optimizer/DDP happy. self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) - def _extract_embedding(self, batch: Batch) -> torch.Tensor: - """Extract item embedding from Batch.dense_features.""" - kt = batch.dense_features[BASE_DATA_GROUP] - return kt[self._embedding_feature_name] - def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -122,7 +112,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: Return: predictions (dict): a dict of predicted result. """ - embedding = self._extract_embedding(batch) + embedding = self._extract_feature(batch) # Training: buffer for the end-of-loop FAISS fit and return dummy # codes — the codebook does not exist yet. @@ -135,10 +125,9 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: if self.is_train: self._offline_buffer.append(embedding.detach().cpu()) B = embedding.shape[0] - n_layers = self._rqkmeans.quantizer.n_layers return { "codes": torch.zeros( - B, n_layers, dtype=torch.long, device=embedding.device + B, self._n_layers, dtype=torch.long, device=embedding.device ) } @@ -154,14 +143,6 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: return predictions - def init_loss(self) -> None: - """Initialize loss modules. - - KMeans has no gradient loss; the codebook is built in - ``on_train_end`` at end of training. - """ - pass - def loss( self, predictions: Dict[str, torch.Tensor], batch: Batch ) -> Dict[str, torch.Tensor]: @@ -181,25 +162,17 @@ def loss( return {"dummy_loss": self._dummy_param.sum() * 0.0} def init_metric(self) -> None: - """Initialize metric modules. + """Initialize metric modules (shared eval metrics + rel_loss). Only eval metrics are registered. During training ``predict`` returns dummy zero codes (the codebook does not exist yet), so - any train-time metric would be either NaN or trivially constant. - ``compute_train_metric`` therefore returns an empty dict, which - the framework already tolerates. + any train-time metric would be either NaN or trivially constant; + the inherited no-op ``update_train_metric`` keeps the train path + empty (``compute_train_metric`` then returns an empty dict, which + the framework already tolerates). """ - self._metric_modules["mse"] = torchmetrics.MeanMetric() + super().init_metric() self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() - self._metric_modules["unique_sid_ratio"] = torchmetrics.MeanMetric() - - def update_train_metric( - self, - predictions: Dict[str, torch.Tensor], - batch: Batch, - ) -> None: - """No-op — see :meth:`init_metric`.""" - return def update_metric( self, @@ -214,9 +187,6 @@ def update_metric( batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ - codes = predictions["codes"] - B = codes.shape[0] - if "input_embedding" in predictions: mse, rel = recon_diagnostics( predictions["input_embedding"], @@ -225,8 +195,7 @@ def update_metric( self._metric_modules["mse"].update(mse) self._metric_modules["rel_loss"].update(rel) - unique_sids = torch.unique(codes, dim=0).shape[0] - self._metric_modules["unique_sid_ratio"].update(unique_sids / B) + self._update_unique_sid_ratio(predictions["codes"]) @torch.no_grad() def on_train_end(self) -> None: diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 632f03ef4..9c8c53590 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -21,14 +21,14 @@ import torch.nn.functional as F import torchmetrics -from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature -from tzrec.models.model import BaseModel +from tzrec.models.sid_model import BaseSidModel from tzrec.modules.sid_generation import RQVAE from tzrec.protos.model_pb2 import ModelConfig -class SidRqvae(BaseModel): +class SidRqvae(BaseSidModel): """SID generation model using RQ-VAE (Encoder + VQ + Decoder). End-to-end differentiable training with reconstruction loss @@ -52,7 +52,6 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) cfg = self._model_config # SidRqvae proto message - self._embedding_feature_name = cfg.embedding_feature_name self._loss_type = cfg.loss_type self._use_clip = cfg.HasField("clip_config") self._clip_feature_name = ( @@ -70,10 +69,6 @@ def __init__( if cfg.latent_weight: rqvae_extra["latent_weight"] = list(cfg.latent_weight) - assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" - n_embed_list = list(cfg.codebook) - n_layers = len(n_embed_list) - use_sinkhorn = True sinkhorn_iters = 5 sinkhorn_epsilon = 10.0 @@ -86,8 +81,8 @@ def __init__( input_dim=cfg.input_dim, embed_dim=cfg.embed_dim, hidden_dims=hidden_dims, - n_layers=n_layers, - n_embed=n_embed_list, + n_layers=self._n_layers, + n_embed=self._n_embed_list, forward_mode=cfg.forward_mode, normalize_residuals=cfg.normalize_residuals, distance_type=cfg.distance_type, @@ -102,21 +97,6 @@ def __init__( **rqvae_extra, ) - def _extract_feature( - self, batch: Batch, feature_name: Optional[str] = None - ) -> torch.Tensor: - """Extract a named feature from Batch.dense_features. - - Args: - batch (Batch): input batch data. - feature_name (str, optional): feature name to extract. - Defaults to self._embedding_feature_name. - """ - if feature_name is None: - feature_name = self._embedding_feature_name - kt = batch.dense_features[BASE_DATA_GROUP] - return kt[feature_name] - def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -176,14 +156,6 @@ def _predict_mixed( } return predictions - def init_loss(self) -> None: - """Initialize loss modules. - - Reconstruction loss and commitment loss are computed internally - by RQVAE and passed through predictions. No external loss module needed. - """ - pass - def loss( self, predictions: Dict[str, torch.Tensor], batch: Batch ) -> Dict[str, torch.Tensor]: @@ -207,9 +179,8 @@ def loss( return losses def init_metric(self) -> None: - """Initialize metric modules.""" - self._metric_modules["mse"] = torchmetrics.MeanMetric() - self._metric_modules["unique_sid_ratio"] = torchmetrics.MeanMetric() + """Initialize metric modules (shared eval metrics + train-path mse).""" + super().init_metric() # Loss values are already logged by the framework via loss(); only # quantization quality needs the train-path metric. unique_sid_ratio @@ -247,13 +218,9 @@ def update_metric( batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ - codes = predictions["codes"] - B = codes.shape[0] - if "x_hat" in predictions: embedding = self._extract_feature(batch) mse = F.mse_loss(predictions["x_hat"], embedding, reduction="mean") self._metric_modules["mse"].update(mse) - unique_sids = torch.unique(codes, dim=0).shape[0] - self._metric_modules["unique_sid_ratio"].update(unique_sids / B) + self._update_unique_sid_ratio(predictions["codes"]) From d1cf153c53492cc4e7d968889b8400ea6e3e50fb Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 02:50:12 +0000 Subject: [PATCH 005/129] [chore] SID: bump copyright year 2024 -> 2026 on newly added files Per review (sid_rqkmeans.py:1). Applies to the 14 net-new SID source/test files added in this branch. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 2 +- tzrec/models/sid_rqkmeans.py | 2 +- tzrec/models/sid_rqkmeans_test.py | 2 +- tzrec/models/sid_rqvae.py | 2 +- tzrec/models/sid_rqvae_test.py | 2 +- tzrec/modules/sid_generation/__init__.py | 2 +- tzrec/modules/sid_generation/clip_loss.py | 2 +- tzrec/modules/sid_generation/kmeans.py | 2 +- tzrec/modules/sid_generation/residual_kmeans_quantizer.py | 2 +- tzrec/modules/sid_generation/residual_quantizer.py | 2 +- tzrec/modules/sid_generation/residual_vector_quantizer.py | 2 +- tzrec/modules/sid_generation/rqvae.py | 2 +- tzrec/modules/sid_generation/types.py | 2 +- tzrec/modules/sid_generation/vector_quantize.py | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index cb0935186..e48827a63 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 22916854a..10acd6fb6 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 2beb9da1a..8dec94d1e 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 9c8c53590..03c85312d 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index c44042c5f..51f6e6436 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py index d9a414556..1c336f84b 100644 --- a/tzrec/modules/sid_generation/__init__.py +++ b/tzrec/modules/sid_generation/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/clip_loss.py b/tzrec/modules/sid_generation/clip_loss.py index c3a020fd5..701576cd4 100644 --- a/tzrec/modules/sid_generation/clip_loss.py +++ b/tzrec/modules/sid_generation/clip_loss.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/kmeans.py b/tzrec/modules/sid_generation/kmeans.py index 1ebb1a64d..40027957f 100644 --- a/tzrec/modules/sid_generation/kmeans.py +++ b/tzrec/modules/sid_generation/kmeans.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py index cdd55c6a0..9adfce9a8 100644 --- a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/residual_quantizer.py b/tzrec/modules/sid_generation/residual_quantizer.py index ff238bdc7..958514b17 100644 --- a/tzrec/modules/sid_generation/residual_quantizer.py +++ b/tzrec/modules/sid_generation/residual_quantizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/residual_vector_quantizer.py b/tzrec/modules/sid_generation/residual_vector_quantizer.py index c5ca829f0..18ed0a480 100644 --- a/tzrec/modules/sid_generation/residual_vector_quantizer.py +++ b/tzrec/modules/sid_generation/residual_vector_quantizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/rqvae.py b/tzrec/modules/sid_generation/rqvae.py index 95e9d48e8..9d0a69beb 100644 --- a/tzrec/modules/sid_generation/rqvae.py +++ b/tzrec/modules/sid_generation/rqvae.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/types.py b/tzrec/modules/sid_generation/types.py index fce27486f..7c9ef9ec4 100644 --- a/tzrec/modules/sid_generation/types.py +++ b/tzrec/modules/sid_generation/types.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/modules/sid_generation/vector_quantize.py b/tzrec/modules/sid_generation/vector_quantize.py index 9e043ffaf..d0a4ffb6a 100644 --- a/tzrec/modules/sid_generation/vector_quantize.py +++ b/tzrec/modules/sid_generation/vector_quantize.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From b2bded0d1076b148a5e6a2228d82325ba66e3db7 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:01:40 +0000 Subject: [PATCH 006/129] [refactor] SID: drop the redundant RQKMeans wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (residual_kmeans.py:293 — "Why two layers of abstraction? I think RQKMeans is not needed"). RQKMeans was a thin nn.Module that just held a ResidualKMeansQuantizer and forwarded every call. SidRqkmeans now owns a ResidualKMeansQuantizer directly (self._quantizer); its forward returns (codes, quantized) so predict unpacks the tuple. Removed the RQKMeans class + export; updated tests/docstrings. Re review (sid_rqkmeans.py:88 — "use config_to_kwargs"): config_to_kwargs is currently broken framework-wide under protobuf 5.x (it passes the removed `including_default_value_fields` kwarg), so it raises on every config. Kept a version-safe MessageToDict with a NOTE pointing at the helper for when it's fixed. 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 21 +++-- tzrec/models/sid_rqkmeans_test.py | 4 +- tzrec/modules/sid_generation/__init__.py | 2 - tzrec/modules/sid_generation/kmeans.py | 4 +- .../residual_kmeans_quantizer.py | 91 +------------------ 5 files changed, 21 insertions(+), 101 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 10acd6fb6..f70bc864a 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -29,7 +29,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid_generation import RQKMeans +from tzrec.modules.sid_generation import ResidualKMeansQuantizer from tzrec.modules.sid_generation.kmeans import recon_diagnostics from tzrec.protos.model_pb2 import ModelConfig from tzrec.utils.logging_util import logger @@ -78,13 +78,18 @@ def __init__( cfg = self._model_config # SidRqkmeans proto message + # NOTE: the project helper ``config_util.config_to_kwargs`` would be + # the idiomatic choice here, but it passes ``MessageToDict(..., + # including_default_value_fields=True)`` which protobuf 5.x removed, + # so it raises framework-wide under the installed protobuf. Use a + # direct (version-safe) MessageToDict until that helper is fixed. self._faiss_kwargs = ( _coerce_proto_numbers(MessageToDict(cfg.faiss_kmeans_kwargs)) if cfg.HasField("faiss_kmeans_kwargs") else {} ) - self._rqkmeans = RQKMeans( + self._quantizer = ResidualKMeansQuantizer( embed_dim=cfg.input_dim, n_layers=self._n_layers, n_embed=self._n_embed_list, @@ -131,14 +136,14 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: ) } - result = self._rqkmeans(embedding) + codes, quantized = self._quantizer(embedding) predictions: Dict[str, torch.Tensor] = { - "codes": result["codes"], + "codes": codes, } if self.is_eval: - predictions["quantized"] = result["quantized"] + predictions["quantized"] = quantized predictions["input_embedding"] = embedding return predictions @@ -265,14 +270,14 @@ def on_train_end(self) -> None: "[SidRqkmeans.on_train_end] rank0 fitting FAISS " "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) ) - self._rqkmeans.train_offline(full, verbose=True) + self._quantizer.train_offline(full, verbose=True) del full # Broadcast centroids and set the init flag locally on every # rank. ``_is_initialized`` is a bool buffer and NCCL's bool # dtype support is inconsistent across versions, so we avoid # a separate broadcast for it — all ranks enter this block in # lockstep, so a local fill_() keeps state consistent. - for layer in self._rqkmeans.quantizer.layers: + for layer in self._quantizer.layers: dist.broadcast(layer.centroids, src=0) layer._is_initialized.fill_(True) dist.barrier() @@ -305,5 +310,5 @@ def on_train_end(self) -> None: # train_offline takes ownership of ``full_np`` (in-place # residual updates); drop our reference after the call. - self._rqkmeans.train_offline(full_np, verbose=True) + self._quantizer.train_offline(full_np, verbose=True) del full_np diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 8dec94d1e..4f85cfa48 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -91,7 +91,7 @@ def test_predict_collects_buffer(self) -> None: total = sum(t.shape[0] for t in model._offline_buffer) self.assertEqual(total, 4 * B) # FAISS not yet triggered: layers should be uninitialized - for layer in model._rqkmeans.quantizer.layers: + for layer in model._quantizer.layers: self.assertFalse(layer.is_initialized) def test_on_train_end_runs_faiss(self) -> None: @@ -116,7 +116,7 @@ def test_on_train_end_runs_faiss(self) -> None: # Buffer should be cleared self.assertEqual(model._offline_buffer, []) # All layers should be initialized + centroids non-zero - for layer in model._rqkmeans.quantizer.layers: + for layer in model._quantizer.layers: self.assertTrue(bool(layer._is_initialized.item())) self.assertGreater(layer.centroids.abs().sum().item(), 0.0) diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py index 1c336f84b..fe7dd800e 100644 --- a/tzrec/modules/sid_generation/__init__.py +++ b/tzrec/modules/sid_generation/__init__.py @@ -17,7 +17,6 @@ ) from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, - RQKMeans, ) from tzrec.modules.sid_generation.residual_quantizer import ( ResidualQuantizer, @@ -48,5 +47,4 @@ "RQVAE", "KMeansLayer", "ResidualKMeansQuantizer", - "RQKMeans", ] diff --git a/tzrec/modules/sid_generation/kmeans.py b/tzrec/modules/sid_generation/kmeans.py index 40027957f..f089e751e 100644 --- a/tzrec/modules/sid_generation/kmeans.py +++ b/tzrec/modules/sid_generation/kmeans.py @@ -15,7 +15,7 @@ SID models: * :class:`KMeansLayer` — per-layer centroid container used by - :class:`ResidualKMeansQuantizer` / :class:`RQKMeans`. Centroids are injected + :class:`ResidualKMeansQuantizer`. Centroids are injected by the FAISS backend via ``load_centroids_``; the only forward path is ``predict``. * :func:`_kmeans` / :func:`_residual_kmeans` — pure-torch Lloyd's @@ -100,7 +100,7 @@ def _kmeans_plus_plus( Selects initial centroids via distance-weighted probability sampling to ensure well-spread starting points. Used by the RQ-VAE codebook - init path (``ResidualVectorQuantizer.kmeans_init``); RQKMeans itself no + init path (``ResidualVectorQuantizer.kmeans_init``); the K-Means backend itself no longer needs it. Args: diff --git a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py index 9adfce9a8..2596993b4 100644 --- a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Multi-layer residual K-Means: ResidualKMeansQuantizer and RQKMeans wrapper. +"""Multi-layer residual K-Means: ResidualKMeansQuantizer. Training is FAISS-only: the codebook is built once via ``train_offline`` over the full embedding matrix; ``forward`` is read-only (predict + lookup). @@ -70,8 +70,8 @@ def __init__( # layers, so non-uniform codebooks would silently train layers 1+ # with ``K=n_embed_list[0]``. Fail fast instead. assert len(set(self.n_embed_list)) == 1, ( - "ResidualKMeansQuantizer / RQKMeans require a uniform codebook " - f"size across layers; got {self.n_embed_list}." + "ResidualKMeansQuantizer requires a uniform codebook size " + f"across layers; got {self.n_embed_list}." ) self.layers = nn.ModuleList( @@ -182,7 +182,7 @@ def train_offline( import faiss except ImportError as e: raise ImportError( - "faiss is required for RQKMeans training. Install via " + "faiss is required for ResidualKMeansQuantizer training. Install via " "`pip install faiss-cpu` or `pip install faiss-gpu`." ) from e @@ -256,86 +256,3 @@ def _calc_loss( """Reconstruction loss diagnostics (MSE + relative L1).""" loss, rel_loss = recon_diagnostics(x, out, epsilon=epsilon) return {"loss": float(loss.item()), "rel_loss": float(rel_loss.item())} - - -class RQKMeans(nn.Module): - """RQ-KMeans: multi-layer residual K-Means trained offline via FAISS. - - No Encoder/Decoder — directly clusters input vectors via residual - K-Means. Codebook is built once by :meth:`train_offline`; ``forward`` - is read-only (assign + lookup). - - Args: - embed_dim (int): feature dimension. Default: 64. - n_layers (int): number of residual quantization layers. Default: 3. - n_embed (int|List[int]): number of clusters per layer. Default: 256. - normalize_residuals (bool): L2-normalize residuals before each - layer. Default: False. - faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to - ``faiss.Kmeans(...)``. - """ - - def __init__( - self, - embed_dim: int = 64, - n_layers: int = 3, - n_embed: Union[int, List[int]] = 256, - normalize_residuals: bool = False, - faiss_kmeans_kwargs: Optional[Dict] = None, - ) -> None: - super().__init__() - self.embed_dim = embed_dim - self.quantizer = ResidualKMeansQuantizer( - embed_dim=embed_dim, - n_layers=n_layers, - n_embed=n_embed, - normalize_residuals=normalize_residuals, - faiss_kmeans_kwargs=faiss_kmeans_kwargs, - ) - - def train_offline( - self, - inputs: Union[torch.Tensor, "np.ndarray"], - verbose: bool = True, - ) -> None: - """Build codebook offline via FAISS. - - Args: - inputs: full embedding matrix, shape (N, embed_dim). Either - a ``torch.Tensor`` or an ``np.ndarray`` (ownership - transferred — array is mutated in-place). - verbose (bool): print per-layer reconstruction loss. - """ - self.quantizer.train_offline(inputs, verbose=verbose) - - def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - """Forward: residual K-Means assignment (no training). - - Args: - x: (B, embed_dim) input features. - - Returns: - dict with keys: - 'codes': (B, n_layers) semantic IDs. - 'quantized': (B, embed_dim) quantized vector (sum of centroids). - """ - codes, quantized = self.quantizer(x) - return { - "codes": codes, - "quantized": quantized, - } - - @torch.no_grad() - def get_codes(self, x: torch.Tensor) -> torch.Tensor: - """Inference: get semantic IDs.""" - return self.quantizer.get_codes(x) - - @torch.no_grad() - def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: - """Reconstruct vectors from semantic IDs (centroid lookup + sum).""" - return self.quantizer.decode_codes(codes) - - @torch.no_grad() - def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: - """Get centroid weights for a specific layer.""" - return self.quantizer.get_codebook_embeddings(layer_idx) From c6efa876b6f7f90cd5cf8cfe71a6aef4a20edd4f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:04:07 +0000 Subject: [PATCH 007/129] [refactor] SID: fold RQVAE module into SidRqvae model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (rqvae.py:26 — "I think RQVAE should be refactored into sid_rqvae.py"). The encoder/decoder MLPs, the ResidualVectorQuantizer, and the CLIP head now live directly on SidRqvae; the forward_rqvae / forward_mixed / loss helpers become private model methods (_forward_rqvae / _forward_mixed / _recon_loss / _masked_recon_loss). Deleted modules/sid_generation/rqvae.py and its RQVAE export. Drops the dead bits the wrapper carried: the never-set _is_inference dispatch and the unreachable commitment_loss=None default (the proto always supplies "l2"). Behavior unchanged; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 186 +++++++++-- tzrec/models/sid_rqvae_test.py | 4 +- tzrec/modules/sid_generation/__init__.py | 4 - tzrec/modules/sid_generation/rqvae.py | 374 ----------------------- 4 files changed, 165 insertions(+), 403 deletions(-) delete mode 100644 tzrec/modules/sid_generation/rqvae.py diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 03c85312d..8ca2a1444 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -11,28 +11,41 @@ """SidRqvae: SID generation model using RQ-VAE (Encoder + VQ + Decoder). -End-to-end differentiable training with reconstruction loss -and commitment loss. Optionally supports CLIP contrastive learning. +End-to-end differentiable training with reconstruction loss and commitment +loss. Optionally supports CLIP contrastive learning. The encoder/decoder, +residual vector quantizer, and CLIP head all live directly on the model — +there is no intermediate ``RQVAE`` module wrapper. """ from typing import Any, Dict, List, Optional +import numpy as np import torch import torch.nn.functional as F import torchmetrics +from torch import nn from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid_generation import RQVAE +from tzrec.modules.sid_generation.clip_loss import MaskedCLIPLoss +from tzrec.modules.sid_generation.residual_vector_quantizer import ( + ResidualVectorQuantizer, +) from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils.logging_util import logger class SidRqvae(BaseSidModel): """SID generation model using RQ-VAE (Encoder + VQ + Decoder). - End-to-end differentiable training with reconstruction loss - and commitment loss. Optionally supports CLIP contrastive learning. + Encoder/Decoder are configurable-depth MLPs built from ``hidden_dims``: + Encoder: input_dim -> hidden_dims[0] -> ... -> embed_dim + Decoder: embed_dim -> ... -> hidden_dims[0] -> input_dim + (ReLU between hidden layers; the decoder mirrors the encoder.) + + When ``clip_config`` is set, ``predict`` runs a dual path and a masked + CLIP contrastive loss is added for the CLIP-pair rows. Args: model_config (ModelConfig): an instance of ModelConfig. @@ -41,6 +54,16 @@ class SidRqvae(BaseSidModel): sample_weights (list): sample weight names. """ + @staticmethod + def _build_mlp(dims: List[int]) -> nn.Sequential: + """Build MLP: dims[0] -> ... -> dims[-1], ReLU between hidden layers.""" + layers: List[nn.Module] = [] + for i in range(len(dims) - 1): + layers.append(nn.Linear(dims[i], dims[i + 1])) + if i < len(dims) - 2: # no activation after the last layer + layers.append(nn.ReLU()) + return nn.Sequential(*layers) + def __init__( self, model_config: ModelConfig, @@ -53,6 +76,9 @@ def __init__( cfg = self._model_config # SidRqvae proto message self._loss_type = cfg.loss_type + assert self._loss_type in ("mse", "l1", "cosine"), ( + f"loss_type must be 'mse', 'l1' or 'cosine', got '{self._loss_type}'" + ) self._use_clip = cfg.HasField("clip_config") self._clip_feature_name = ( cfg.clip_config.clip_feature_name if self._use_clip else None @@ -61,13 +87,12 @@ def __init__( cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None ) - hidden_dims = list(cfg.hidden_dims) if cfg.hidden_dims else None - # Only forward latent_weight when the user set it (repeated field is - # empty when unset); otherwise let RQVAE / ResidualVectorQuantizer - # apply their signature default (1.0, 0.5). - rqvae_extra: Dict[str, Any] = {} - if cfg.latent_weight: - rqvae_extra["latent_weight"] = list(cfg.latent_weight) + input_dim = cfg.input_dim + embed_dim = cfg.embed_dim + hidden_dims = list(cfg.hidden_dims) if cfg.hidden_dims else [input_dim // 2] + # latent_weight defaults to (1.0, 0.5) when the user leaves the + # repeated field empty. + latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) use_sinkhorn = True sinkhorn_iters = 5 @@ -77,26 +102,141 @@ def __init__( sinkhorn_iters = cfg.sinkhorn_config.iters sinkhorn_epsilon = cfg.sinkhorn_config.epsilon - self._rqvae = RQVAE( - input_dim=cfg.input_dim, - embed_dim=cfg.embed_dim, - hidden_dims=hidden_dims, + self._encoder = self._build_mlp([input_dim, *hidden_dims, embed_dim]) + # Decoder is the symmetric reverse of the encoder. + self._decoder = self._build_mlp([embed_dim, *reversed(hidden_dims), input_dim]) + + self._quantizer = ResidualVectorQuantizer( + embed_dim=embed_dim, n_layers=self._n_layers, n_embed=self._n_embed_list, forward_mode=cfg.forward_mode, normalize_residuals=cfg.normalize_residuals, distance_type=cfg.distance_type, commitment_loss=cfg.commitment_loss, + latent_weight=latent_weight, rotation_trick=cfg.rotation_trick, kmeans_init=cfg.kmeans_init, use_sinkhorn=use_sinkhorn, sinkhorn_iters=sinkhorn_iters, sinkhorn_epsilon=sinkhorn_epsilon, - loss_type=cfg.loss_type, - use_clip=self._use_clip, - **rqvae_extra, ) + # CLIP contrastive head (optional). + if self._use_clip: + self._logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._masked_clip_loss_fn = MaskedCLIPLoss() + + logger.info( + "SidRqvae init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " + "n_layers=%d, n_embed=%s, loss_type=%s, use_clip=%s", + input_dim, + embed_dim, + hidden_dims, + self._n_layers, + self._n_embed_list, + self._loss_type, + self._use_clip, + ) + + # ----- encode / decode / loss helpers (formerly RQVAE) ----- + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode. (B, input_dim) -> (B, embed_dim).""" + return self._encoder(x) + + def _decode(self, z_q: torch.Tensor) -> torch.Tensor: + """Decode. (B, embed_dim) -> (B, input_dim).""" + return self._decoder(z_q) + + def _recon_loss(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Mean reconstruction loss for the configured ``loss_type``.""" + if self._loss_type == "mse": + return F.mse_loss(x_hat, x, reduction="mean") + elif self._loss_type == "l1": + return F.l1_loss(x_hat, x, reduction="mean") + else: # 'cosine' + return (1 - F.cosine_similarity(x_hat, x, dim=1)).mean() + + def _forward_rqvae( + self, x: torch.Tensor, temperature: float = 1.0 + ) -> Dict[str, torch.Tensor]: + """Standard RQ-VAE forward: encode -> quantize -> decode -> loss.""" + z_e = self._encode(x) + quant = self._quantizer(z_e, temperature=temperature) + x_hat = self._decode(quant.quantized_embeddings) + + recon_loss = self._recon_loss(x_hat, x) + quant_loss = quant.quantization_loss + return { + "x_hat": x_hat, + "codes": quant.cluster_ids, + "quantized": quant.quantized_embeddings, + "reconstruction_loss": recon_loss, + "quantization_loss": quant_loss, + "loss": recon_loss + quant_loss, + } + + def _masked_recon_loss( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + recon_mask: torch.Tensor, + ) -> torch.Tensor: + """Per-sample recon loss masked to recon rows (no data-dependent branch).""" + if self._loss_type == "mse": + per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) + elif self._loss_type == "l1": + per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) + else: # 'cosine' + per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) + n_recon = recon_mask.float().sum().clamp(min=1) + return (per_sample * recon_mask.float()).sum() / n_recon + + def _forward_mixed( + self, + fea1: torch.Tensor, + fea2: torch.Tensor, + clip_mask: torch.Tensor, + temperature: float = 1.0, + ) -> Dict[str, torch.Tensor]: + """Mixed recon + CLIP forward (all rows dual-pathed; mask splits loss).""" + z_e1 = self._encode(fea1) + quant1 = self._quantizer(z_e1, temperature=temperature) + x_hat1 = self._decode(quant1.quantized_embeddings) + + z_e2 = self._encode(fea2) + quant2 = self._quantizer(z_e2, temperature=temperature) + x_hat2 = self._decode(quant2.quantized_embeddings) + + recon_mask = ~clip_mask + recon_loss = self._masked_recon_loss(x_hat1, fea1, recon_mask) + + features = { + "image_embed": x_hat1, + "text_embed": x_hat2, + "image_embed_ori": fea1, + "text_embed_ori": fea2, + "logit_scale_self": self._logit_scale_self.exp(), + "logit_scale_cl": self._logit_scale_cl.exp(), + "logit_scale": self._logit_scale.exp(), + } + clip_result = self._masked_clip_loss_fn(features, clip_mask) + + commitment = (quant1.quantization_loss + quant2.quantization_loss) / 2 + return { + "codes": quant1.cluster_ids, + "quantized": quant1.quantized_embeddings, + "x_hat": x_hat1, + "recon_loss": recon_loss, + "clip_loss": clip_result["clip_loss"], + "commitment_loss": commitment, + } + + # ----- BaseModel interface ----- + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -115,7 +255,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: """Standard RQ-VAE: encode -> quantize -> decode -> loss.""" - result = self._rqvae.forward_rqvae(embedding) + result = self._forward_rqvae(embedding) predictions: Dict[str, torch.Tensor] = { "codes": result["codes"], @@ -132,11 +272,11 @@ def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: def _predict_mixed( self, embedding: torch.Tensor, batch: Batch ) -> Dict[str, torch.Tensor]: - """Mixed recon + CLIP: extract fea2 and clip_mask, call forward_mixed.""" + """Mixed recon + CLIP: extract fea2 and clip_mask, run the dual path.""" # Inference skips the dual path: fea2 / clip_mask aren't needed # when we only emit codes. if self._is_inference: - result = self._rqvae.forward_rqvae(embedding) + result = self._forward_rqvae(embedding) return {"codes": result["codes"]} fea2 = self._extract_feature(batch, self._clip_feature_name) @@ -144,7 +284,7 @@ def _predict_mixed( is_clip_pair_raw = self._extract_feature(batch, self._is_clip_pair_feature_name) clip_mask = is_clip_pair_raw.view(is_clip_pair_raw.shape[0], -1)[:, 0] > 0.5 - result = self._rqvae.forward_mixed(embedding, fea2, clip_mask) + result = self._forward_mixed(embedding, fea2, clip_mask) predictions: Dict[str, torch.Tensor] = { "codes": result["codes"], diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 51f6e6436..c8e3eafa8 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -373,7 +373,7 @@ def test_sinkhorn_config_enabled_false(self) -> None: model = SidRqvae(model_config=model_config, features=[], labels=[]) init_parameters(model, device=torch.device("cpu")) - for layer in model._rqvae.quantizer.layers: + for layer in model._quantizer.layers: self.assertFalse(layer.use_sinkhorn) def test_sinkhorn_config_default_enabled(self) -> None: @@ -382,7 +382,7 @@ def test_sinkhorn_config_default_enabled(self) -> None: Back-compat for legacy configs that never set the sub-config. """ model = self._create_model() # no sinkhorn_config set - for layer in model._rqvae.quantizer.layers: + for layer in model._quantizer.layers: self.assertTrue(layer.use_sinkhorn) def test_commitment_loss_invalid_raises(self) -> None: diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py index fe7dd800e..e466f057e 100644 --- a/tzrec/modules/sid_generation/__init__.py +++ b/tzrec/modules/sid_generation/__init__.py @@ -24,9 +24,6 @@ from tzrec.modules.sid_generation.residual_vector_quantizer import ( ResidualVectorQuantizer, ) -from tzrec.modules.sid_generation.rqvae import ( - RQVAE, -) from tzrec.modules.sid_generation.types import ( QuantizeForwardMode, QuantizeOutput, @@ -44,7 +41,6 @@ "GatherLayer", "ResidualQuantizer", "ResidualVectorQuantizer", - "RQVAE", "KMeansLayer", "ResidualKMeansQuantizer", ] diff --git a/tzrec/modules/sid_generation/rqvae.py b/tzrec/modules/sid_generation/rqvae.py deleted file mode 100644 index 9d0a69beb..000000000 --- a/tzrec/modules/sid_generation/rqvae.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""RQVAE: Encoder + ResidualVectorQuantizer + Decoder top-level wrapper.""" - -from typing import Dict, List, Optional, Sequence, Union - -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F - -from tzrec.modules.sid_generation.clip_loss import MaskedCLIPLoss -from tzrec.modules.sid_generation.residual_vector_quantizer import ( - ResidualVectorQuantizer, -) -from tzrec.utils.logging_util import logger - - -class RQVAE(nn.Module): - """RQ-VAE: Encoder + ResidualVectorQuantizer + Decoder. - - Supports optional CLIP contrastive learning. When use_clip=True, - forward accepts paired inputs (fea1, fea2) and computes CLIP loss - via a siamese network (shared parameters). - - Encoder/Decoder are configurable-depth MLPs built via hidden_dims: - Encoder: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1] -> embed_dim - Decoder: embed_dim -> hidden_dims[-1] -> ... -> hidden_dims[0] -> input_dim - ReLU activation between hidden layers. Decoder reverses hidden_dims - for symmetric structure. - - Args: - input_dim (int): original embedding dimension. Default: 512. - embed_dim (int): latent space dimension. Default: 64. - hidden_dims (List[int]): encoder hidden layer dimensions. - Decoder automatically reverses for symmetry. - Default: [input_dim // 2]. - n_layers (int): number of residual quantization layers. Default: 3. - n_embed (int|List[int]): codebook size per layer. Default: 256. - forward_mode (str): VQ forward mode ('ste'|'gumbel_softmax'). - Default: 'ste'. - normalize_residuals (bool): L2-normalize residuals. Default: False. - distance_type (str|List[str]): distance metric ('l2'|'cosine'). - Default: 'l2'. - commitment_loss (str|None): commitment loss type ('l2'|'cos'). - Default: follows loss_type (al_sid behavior). - latent_weight (List[float]): commitment loss weights [w1, w2]. - Default: [1.0, 0.5]. - rotation_trick (bool): STE rotation trick. Default: False. - kmeans_init (bool): KMeans codebook initialization. Default: True. - use_sinkhorn (bool): Sinkhorn uniform assignment. Default: True. - sinkhorn_iters (int): Sinkhorn iterations. Default: 5. - sinkhorn_epsilon (float): Sinkhorn sharpness. Default: 10.0. - loss_type (str): reconstruction loss ('mse'|'l1'|'cosine'). - Default: 'mse'. - use_clip (bool): enable CLIP contrastive learning. Default: False. - """ - - @staticmethod - def _build_mlp(dims: List[int]) -> nn.Sequential: - """Build MLP: dims[0] -> ... -> dims[-1], ReLU between hidden layers.""" - layers: List[nn.Module] = [] - for i in range(len(dims) - 1): - layers.append(nn.Linear(dims[i], dims[i + 1])) - if i < len(dims) - 2: # no activation after last layer - layers.append(nn.ReLU()) - return nn.Sequential(*layers) - - def __init__( - self, - input_dim: int = 512, - embed_dim: int = 64, - hidden_dims: Optional[List[int]] = None, - n_layers: int = 3, - n_embed: Union[int, List[int]] = 256, - forward_mode: str = "ste", - normalize_residuals: bool = False, - distance_type: Union[str, List[str]] = "l2", - commitment_loss: Optional[str] = None, - latent_weight: Sequence[float] = (1.0, 0.5), - rotation_trick: bool = False, - kmeans_init: bool = True, - use_sinkhorn: bool = True, - sinkhorn_iters: int = 5, - sinkhorn_epsilon: float = 10.0, - loss_type: str = "mse", - use_clip: bool = False, - ) -> None: - super().__init__() - - assert loss_type in ("mse", "l1", "cosine"), ( - f"loss_type must be 'mse', 'l1' or 'cosine', got '{loss_type}'" - ) - self.loss_type = loss_type - self.use_clip = use_clip - self.input_dim = input_dim - self.embed_dim = embed_dim - - self._is_inference = False - - if hidden_dims is None: - hidden_dims = [input_dim // 2] - - # commitment_loss defaults to follow loss_type (al_sid behavior: - # commitment_loss=loss_type, so mse -> l2 branch) - if commitment_loss is None: - commitment_loss = "l2" if loss_type == "mse" else loss_type - - enc_dims = [input_dim] + list(hidden_dims) + [embed_dim] - self.encoder = self._build_mlp(enc_dims) - - # Decoder is the symmetric reverse of the encoder. - dec_dims = [embed_dim] + list(reversed(hidden_dims)) + [input_dim] - self.decoder = self._build_mlp(dec_dims) - - self.quantizer = ResidualVectorQuantizer( - embed_dim=embed_dim, - n_layers=n_layers, - n_embed=n_embed, - forward_mode=forward_mode, - normalize_residuals=normalize_residuals, - distance_type=distance_type, - commitment_loss=commitment_loss, - latent_weight=latent_weight, - rotation_trick=rotation_trick, - kmeans_init=kmeans_init, - use_sinkhorn=use_sinkhorn, - sinkhorn_iters=sinkhorn_iters, - sinkhorn_epsilon=sinkhorn_epsilon, - ) - - # CLIP contrastive learning (optional) - if use_clip: - self.logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self.masked_clip_loss_fn = MaskedCLIPLoss() - - logger.info( - "RQVAE init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " - "n_layers=%d, n_embed=%s, forward_mode=%s, " - "normalize_residuals=%s, distance_type=%s, " - "commitment_loss=%s, latent_weight=%s, rotation_trick=%s, " - "kmeans_init=%s, use_sinkhorn=%s, " - "sinkhorn_iters=%d, sinkhorn_epsilon=%s, " - "loss_type=%s, use_clip=%s", - input_dim, - embed_dim, - hidden_dims, - n_layers, - n_embed, - forward_mode, - normalize_residuals, - distance_type, - commitment_loss, - list(latent_weight), - rotation_trick, - kmeans_init, - use_sinkhorn, - sinkhorn_iters, - sinkhorn_epsilon, - loss_type, - use_clip, - ) - - def encode(self, x: torch.Tensor) -> torch.Tensor: - """Encode. (B, input_dim) -> (B, embed_dim).""" - return self.encoder(x) - - def decode(self, z_q: torch.Tensor) -> torch.Tensor: - """Decode. (B, embed_dim) -> (B, input_dim).""" - return self.decoder(z_q) - - def _cosine_loss(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: - """Cosine distance loss: 1 - mean(cos_sim).""" - return (1 - F.cosine_similarity(x1, x2, dim=1)).mean() - - def compute_loss( - self, - x: torch.Tensor, - x_hat: torch.Tensor, - quant_loss: torch.Tensor, - ) -> Dict[str, torch.Tensor]: - """Compute reconstruction loss + quantization loss + total loss. - - loss_total = recon_loss + quant_loss - Note: al_sid latent_loss_weight is declared but unused; - commitment_loss is added 1:1 with recon_loss. We align with this. - - Args: - x: original input, shape (B, input_dim). - x_hat: reconstructed output, shape (B, input_dim). - quant_loss: quantization (commitment) loss scalar. - - Returns: - dict with 'reconstruction_loss', 'quantization_loss', 'loss'. - """ - if self.loss_type == "mse": - recon_loss = F.mse_loss(x_hat, x, reduction="mean") - elif self.loss_type == "l1": - recon_loss = F.l1_loss(x_hat, x, reduction="mean") - elif self.loss_type == "cosine": - recon_loss = self._cosine_loss(x_hat, x) - else: - raise ValueError(f"Unsupported loss_type: {self.loss_type}") - - loss_total = recon_loss + quant_loss - - return { - "reconstruction_loss": recon_loss, - "quantization_loss": quant_loss, - "loss": loss_total, - } - - def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]: - """Dispatch based on use_clip. - - use_clip=False: forward(x) -> forward_rqvae(x) - use_clip=True: forward(fea1, fea2, clip_mask) -> forward_mixed(...) - """ - if self._is_inference or not self.use_clip: - assert len(args) >= 1, "Standard mode requires (x,)" - return self.forward_rqvae(args[0], **kwargs) - else: - assert len(args) == 3, "Mixed mode requires (fea1, fea2, clip_mask)" - return self.forward_mixed(args[0], args[1], args[2], **kwargs) - - def forward_rqvae( - self, x: torch.Tensor, temperature: float = 1.0 - ) -> Dict[str, torch.Tensor]: - """Standard RQ-VAE forward: encode -> quantize -> decode -> loss. - - Args: - x: (B, input_dim) original embedding. - temperature: Gumbel-Softmax temperature. - - Returns: - dict with keys: 'x_hat', 'codes', 'quantized', - 'reconstruction_loss', 'quantization_loss', 'loss'. - """ - z_e = self.encode(x) - quant_output = self.quantizer(z_e, temperature=temperature) - x_hat = self.decode(quant_output.quantized_embeddings) - - losses = self.compute_loss(x, x_hat, quant_output.quantization_loss) - - return { - "x_hat": x_hat, - "codes": quant_output.cluster_ids, - "quantized": quant_output.quantized_embeddings, - **losses, - } - - def _compute_masked_recon_loss( - self, - x_hat: torch.Tensor, - x: torch.Tensor, - recon_mask: torch.Tensor, - ) -> torch.Tensor: - """Compute per-sample recon loss, masked to recon rows only. - - No boolean indexing, no data-dependent branching, - compatible with torch.compile. - - Args: - x_hat: (B, D) reconstructed output. - x: (B, D) original input. - recon_mask: (B,) bool, True = recon row. - """ - if self.loss_type == "mse": - per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) - elif self.loss_type == "l1": - per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - elif self.loss_type == "cosine": - per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) - else: - raise ValueError(f"Unsupported loss_type: {self.loss_type}") - n_recon = recon_mask.float().sum().clamp(min=1) - return (per_sample * recon_mask.float()).sum() / n_recon - - def forward_mixed( - self, - fea1: torch.Tensor, - fea2: torch.Tensor, - clip_mask: torch.Tensor, - temperature: float = 1.0, - ) -> Dict[str, torch.Tensor]: - """Mixed recon + CLIP forward. - - All samples go through dual paths; mask separates recon and clip - loss contributions. - - Args: - fea1: (B, input_dim) main embedding (all rows valid). - fea2: (B, input_dim) clip embedding (recon rows == fea1). - clip_mask: (B,) bool, True = clip sample. - temperature: Gumbel-Softmax temperature. - """ - # Step 1: dual-path encode -> quantize -> decode - z_e1 = self.encode(fea1) - quant1 = self.quantizer(z_e1, temperature=temperature) - x_hat1 = self.decode(quant1.quantized_embeddings) - - z_e2 = self.encode(fea2) - quant2 = self.quantizer(z_e2, temperature=temperature) - x_hat2 = self.decode(quant2.quantized_embeddings) - - # Step 2: recon loss (only recon rows, no branching) - recon_mask = ~clip_mask - recon_loss = self._compute_masked_recon_loss(x_hat1, fea1, recon_mask) - - # Step 3: masked CLIP loss (only clip rows) - features = { - "image_embed": x_hat1, - "text_embed": x_hat2, - "image_embed_ori": fea1, - "text_embed_ori": fea2, - "logit_scale_self": self.logit_scale_self.exp(), - "logit_scale_cl": self.logit_scale_cl.exp(), - "logit_scale": self.logit_scale.exp(), - } - clip_result = self.masked_clip_loss_fn(features, clip_mask) - - # Step 4: commitment loss (average of two paths) - commitment = (quant1.quantization_loss + quant2.quantization_loss) / 2 - - return { - "codes": quant1.cluster_ids, - "quantized": quant1.quantized_embeddings, - "x_hat": x_hat1, - "recon_loss": recon_loss, - "clip_loss": clip_result["clip_loss"], - "clip_acc": clip_result["clip_acc"], - "loss_self": clip_result["loss_self"], - "loss_ori": clip_result["loss_ori"], - "loss_cl": clip_result["loss_cl"], - "commitment_loss": commitment, - "loss": recon_loss + clip_result["clip_loss"] + commitment, - } - - @torch.no_grad() - def get_codes(self, x: torch.Tensor) -> torch.Tensor: - """Inference: get semantic IDs. - - Args: - x: (B, input_dim) original embedding. - - Returns: - Tensor: codes, shape (B, n_layers). - """ - z_e = self.encode(x) - return self.quantizer.get_codes(z_e) - - @torch.no_grad() - def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: - """Reconstruct embedding from semantic IDs (through decoder). - - Args: - codes: (B, n_layers) semantic ID codes. - - Returns: - Tensor: x_hat, shape (B, input_dim). - """ - z_q = self.quantizer.decode_codes(codes) - return self.decode(z_q) From 0bd7a44b6c1965ec04afa21998962e48c7c92366 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:06:52 +0000 Subject: [PATCH 008/129] [refactor] SID: address review nits (std_mean, test config, .cpu() doc) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - vector_quantize.py:204 — use torch.std_mean instead of torch.var_mean + rsqrt for the pre-Sinkhorn z-score (cleaner, equivalent). - sid_rq{vae,kmeans}_test.py — drop the feature_groups that referenced "item_emb" while features=[] (SID models read the dense feature directly and never consume feature_groups); config is now internally consistent. - sid_rqkmeans.py — document why predict buffers to host (.cpu()): the full corpus is accumulated before one FAISS pass, so GPU residency would OOM and faiss-cpu can't take CUDA tensors. 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 5 +++++ tzrec/models/sid_rqkmeans_test.py | 11 +++-------- tzrec/models/sid_rqvae_test.py | 19 +++---------------- .../modules/sid_generation/vector_quantize.py | 4 ++-- 4 files changed, 13 insertions(+), 26 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index f70bc864a..6199e6b24 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -121,6 +121,11 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: # Training: buffer for the end-of-loop FAISS fit and return dummy # codes — the codebook does not exist yet. + # We move to host (.cpu()) deliberately: the whole corpus is + # accumulated before the single FAISS pass, so keeping every step's + # batch resident in GPU memory would OOM, and the common faiss-cpu + # build cannot consume CUDA tensors anyway. (A faiss-gpu fit could + # take a GPU tensor, but that is the exception, not the default.) # TODO(perf): .cpu() is a synchronous D2H per step and the buffer # grows unbounded with steps. Rework to either (a) GPU-resident # buffer + bulk D2H in on_train_end with size cap, or (b) replace diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 4f85cfa48..b9442476c 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -52,15 +52,10 @@ def _create_model(self, input_dim=32, n_layers=2, niter=5): faiss_kmeans_kwargs=faiss_kwargs, embedding_feature_name="item_emb", ) - feature_groups = [ - model_pb2.FeatureGroupConfig( - group_name="deep", - feature_names=["item_emb"], - group_type=model_pb2.FeatureGroupType.DEEP, - ), - ] + # SID models read the item-embedding dense feature directly from the + # batch; they do not consume feature_groups, so none is set (which + # keeps the config consistent with the empty ``features`` list). model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, sid_rqkmeans=sid_rqkmeans_cfg, ) model = SidRqkmeans(model_config=model_config, features=[], labels=[]) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index c8e3eafa8..686a9a720 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -65,15 +65,10 @@ def _create_model(self, use_clip=False, input_dim=32, embed_dim=8, n_layers=2): ) ) - feature_groups = [ - model_pb2.FeatureGroupConfig( - group_name="deep", - feature_names=["item_emb"], - group_type=model_pb2.FeatureGroupType.DEEP, - ), - ] + # SID models read the item-embedding dense feature directly from the + # batch; they do not consume feature_groups, so none is set (which + # keeps the config consistent with the empty ``features`` list). model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, sid_rqvae=sid_rqvae_cfg, ) model = SidRqvae(model_config=model_config, features=[], labels=[]) @@ -359,15 +354,7 @@ def test_sinkhorn_config_enabled_false(self) -> None: sid_rqvae_cfg.sinkhorn_config.CopyFrom( sid_model_pb2.SinkhornConfig(enabled=False) ) - feature_groups = [ - model_pb2.FeatureGroupConfig( - group_name="deep", - feature_names=["item_emb"], - group_type=model_pb2.FeatureGroupType.DEEP, - ), - ] model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, sid_rqvae=sid_rqvae_cfg, ) model = SidRqvae(model_config=model_config, features=[], labels=[]) diff --git a/tzrec/modules/sid_generation/vector_quantize.py b/tzrec/modules/sid_generation/vector_quantize.py index d0a4ffb6a..bc429ca75 100644 --- a/tzrec/modules/sid_generation/vector_quantize.py +++ b/tzrec/modules/sid_generation/vector_quantize.py @@ -201,8 +201,8 @@ def _find_nearest_embedding( if self.training and self.use_sinkhorn: # Sinkhorn requires non-negative cost; z-score then shift. - var, mean = torch.var_mean(distances, unbiased=False) - distances = (distances - mean) * var.add(1e-12).rsqrt() + std, mean = torch.std_mean(distances, unbiased=False) + distances = (distances - mean) / std.add(1e-12) distances = distances - distances.min() # Sinkhorn optimal-transport assignment From e867af286ad3f3a6e18f0f353fa52254d0cf8636 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:09:53 +0000 Subject: [PATCH 009/129] [test] SID: add module-level unit tests for sid_generation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (__init__.py:12 — "Add unit tests for all modules and functions"). Adds colocated *_test.py covering the modules that previously only had indirect, model-level coverage: - kmeans_test.py: recon_diagnostics, _squared_euclidean_distance (+ chunked-equivalence), _kmeans / _residual_kmeans shapes, and the KMeansLayer load/predict/round-trip + mid-fit-checkpoint guard. - vector_quantize_test.py: VectorQuantize STE/Gumbel x l2/cosine x sinkhorn forward, STE gradient-to-input, eval plain-lookup. - residual_quantizer_test.py: normalize_n_embed, the abstract base's shared output_dim/decode_codes + NotImplementedError primitives, and both subclasses (ResidualVectorQuantizer / ResidualKMeansQuantizer) incl. the non-uniform-codebook reject and an offline FAISS fit. - clip_loss_test.py: single-process MaskedCLIPLoss (all-clip finite, all-recon zero, backward-to-embeddings) and _all_gather_with_grad single-process identity. 51/51 SID unit tests pass (18 model + 33 module). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../modules/sid_generation/clip_loss_test.py | 75 +++++++++++ tzrec/modules/sid_generation/kmeans_test.py | 108 +++++++++++++++ .../sid_generation/residual_quantizer_test.py | 127 ++++++++++++++++++ .../sid_generation/vector_quantize_test.py | 71 ++++++++++ 4 files changed, 381 insertions(+) create mode 100644 tzrec/modules/sid_generation/clip_loss_test.py create mode 100644 tzrec/modules/sid_generation/kmeans_test.py create mode 100644 tzrec/modules/sid_generation/residual_quantizer_test.py create mode 100644 tzrec/modules/sid_generation/vector_quantize_test.py diff --git a/tzrec/modules/sid_generation/clip_loss_test.py b/tzrec/modules/sid_generation/clip_loss_test.py new file mode 100644 index 000000000..227f8afaa --- /dev/null +++ b/tzrec/modules/sid_generation/clip_loss_test.py @@ -0,0 +1,75 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from tzrec.modules.sid_generation.clip_loss import ( + MaskedCLIPLoss, + _all_gather_with_grad, +) + + +class AllGatherWithGradTest(unittest.TestCase): + def test_single_process_identity(self) -> None: + a, b = torch.randn(3, 4), torch.randn(3, 4) + out = _all_gather_with_grad([a, b]) + self.assertIs(out[0], a) + self.assertIs(out[1], b) + + +class MaskedCLIPLossTest(unittest.TestCase): + """Single-process tests for the masked CLIP loss.""" + + def _features(self, B: int, D: int) -> dict: + torch.manual_seed(0) + scale = torch.tensor(np.log(1 / 0.07)).exp() + return { + "image_embed": torch.randn(B, D, requires_grad=True), + "text_embed": torch.randn(B, D, requires_grad=True), + "image_embed_ori": torch.randn(B, D), + "text_embed_ori": torch.randn(B, D), + "logit_scale_self": scale, + "logit_scale_cl": scale, + "logit_scale": scale, + } + + def test_forward_all_clip_finite(self) -> None: + loss_fn = MaskedCLIPLoss() + feats = self._features(6, 8) + mask = torch.ones(6, dtype=torch.bool) + out = loss_fn(feats, mask) + self.assertIn("clip_loss", out) + self.assertTrue(torch.isfinite(out["clip_loss"])) + self.assertGreater(out["clip_loss"].item(), 0.0) + + def test_all_recon_mask_zero_loss(self) -> None: + loss_fn = MaskedCLIPLoss() + feats = self._features(6, 8) + mask = torch.zeros(6, dtype=torch.bool) # no clip rows + out = loss_fn(feats, mask) + # No clip rows -> masked average is exactly zero (and finite). + self.assertTrue(torch.isfinite(out["clip_loss"])) + self.assertAlmostEqual(out["clip_loss"].item(), 0.0, places=6) + + def test_backward_flows_to_embeddings(self) -> None: + loss_fn = MaskedCLIPLoss() + feats = self._features(6, 8) + mask = torch.ones(6, dtype=torch.bool) + loss_fn(feats, mask)["clip_loss"].backward() + self.assertIsNotNone(feats["image_embed"].grad) + self.assertTrue(torch.isfinite(feats["image_embed"].grad).all()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid_generation/kmeans_test.py b/tzrec/modules/sid_generation/kmeans_test.py new file mode 100644 index 000000000..531b33126 --- /dev/null +++ b/tzrec/modules/sid_generation/kmeans_test.py @@ -0,0 +1,108 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid_generation.kmeans import ( + KMeansLayer, + _kmeans, + _residual_kmeans, + _squared_euclidean_distance, + recon_diagnostics, +) + + +class KmeansHelpersTest(unittest.TestCase): + """Tests for the pure-torch K-Means helpers.""" + + def test_recon_diagnostics_zero_on_identity(self) -> None: + x = torch.randn(8, 4) + mse, rel = recon_diagnostics(x, x.clone()) + self.assertAlmostEqual(mse.item(), 0.0, places=6) + self.assertAlmostEqual(rel.item(), 0.0, places=6) + + def test_squared_euclidean_distance(self) -> None: + x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) + d = _squared_euclidean_distance(x, y) + self.assertEqual(d.shape, (2, 2)) + # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 + torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) + + def test_squared_euclidean_distance_chunked_matches(self) -> None: + x = torch.randn(120, 5) + y = torch.randn(7, 5) + full = _squared_euclidean_distance(x, y, chunk_size=1000) + chunked = _squared_euclidean_distance(x, y, chunk_size=16) + torch.testing.assert_close(full, chunked) + + def test_kmeans_shapes_and_assignment_range(self) -> None: + torch.manual_seed(0) + samples = torch.randn(200, 6) + centroids, assignments = _kmeans(samples, n_clusters=8, n_iters=5) + self.assertEqual(centroids.shape, (8, 6)) + self.assertEqual(assignments.shape, (200,)) + self.assertTrue((assignments >= 0).all() and (assignments < 8).all()) + + def test_residual_kmeans_per_layer_centers(self) -> None: + torch.manual_seed(0) + samples = torch.randn(200, 6) + centers = _residual_kmeans(samples, [8, 4], n_iters=5) + self.assertEqual(len(centers), 2) + self.assertEqual(centers[0].shape, (8, 6)) + self.assertEqual(centers[1].shape, (4, 6)) + + +class KMeansLayerTest(unittest.TestCase): + """Tests for the single KMeansLayer.""" + + def test_uninitialized_by_default(self) -> None: + layer = KMeansLayer(n_clusters=4, n_features=3) + self.assertFalse(layer.is_initialized) + self.assertEqual(layer.centroids.abs().sum().item(), 0.0) + + def test_load_centroids_and_predict(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) + layer.load_centroids_(centroids) + self.assertTrue(layer.is_initialized) + + batch = torch.tensor([[0.1, 0.0], [9.0, 11.0]]) + codes = layer.predict(batch) + torch.testing.assert_close(codes, torch.tensor([0, 1])) + + def test_load_centroids_shape_mismatch_raises(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + with self.assertRaises(AssertionError): + layer.load_centroids_(torch.zeros(3, 2)) + + def test_mid_fit_checkpoint_rejected(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + sd = layer.state_dict() + # Simulate a mid-fit checkpoint: flag True but centroids still zero. + sd["_is_initialized"] = torch.tensor(True) + fresh = KMeansLayer(n_clusters=2, n_features=2) + with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): + fresh.load_state_dict(sd) + + def test_post_fit_checkpoint_round_trips(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + fresh = KMeansLayer(n_clusters=2, n_features=2) + fresh.load_state_dict(layer.state_dict()) + self.assertTrue(fresh.is_initialized) + torch.testing.assert_close(fresh.centroids, layer.centroids) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py new file mode 100644 index 000000000..f5893cf4c --- /dev/null +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -0,0 +1,127 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) +from tzrec.modules.sid_generation.residual_quantizer import ( + ResidualQuantizer, + normalize_n_embed, +) +from tzrec.modules.sid_generation.residual_vector_quantizer import ( + ResidualVectorQuantizer, +) +from tzrec.modules.sid_generation.types import ResidualQuantizerOutput + + +class NormalizeNEmbedTest(unittest.TestCase): + def test_scalar_broadcasts(self) -> None: + self.assertEqual(normalize_n_embed(256, 3), [256, 256, 256]) + + def test_list_passes_through(self) -> None: + self.assertEqual(normalize_n_embed([8, 4, 2], 3), [8, 4, 2]) + + def test_length_mismatch_raises(self) -> None: + with self.assertRaises(AssertionError): + normalize_n_embed([8, 4], 3) + + +class ResidualQuantizerBaseTest(unittest.TestCase): + """The abstract base owns shared state but not the backend primitives.""" + + def test_shared_state_and_output_dim(self) -> None: + rq = ResidualQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertEqual(rq.output_dim(), 4) + self.assertEqual(rq.n_embed_list, [8, 8]) + self.assertEqual(len(rq.layers), 0) # subclasses populate this + + def test_abstract_primitives_raise(self) -> None: + rq = ResidualQuantizer(embed_dim=4, n_layers=2) + x = torch.randn(3, 4) + with self.assertRaises(NotImplementedError): + rq.forward(x) + with self.assertRaises(NotImplementedError): + rq.get_codes(x) + with self.assertRaises(NotImplementedError): + rq.get_codebook_embeddings(0) + # decode_codes is concrete but delegates to the abstract _lookup_code. + with self.assertRaises(NotImplementedError): + rq.decode_codes(torch.zeros(3, 2, dtype=torch.long)) + + +class ResidualVectorQuantizerTest(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(0) + self.rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=3, n_embed=16, kmeans_init=False + ) + + def test_is_subclass(self) -> None: + self.assertIsInstance(self.rvq, ResidualQuantizer) + + def test_forward_output(self) -> None: + self.rvq.train() + out = self.rvq(torch.randn(5, 8)) + self.assertIsInstance(out, ResidualQuantizerOutput) + self.assertEqual(out.cluster_ids.shape, (5, 3)) + self.assertEqual(out.quantized_embeddings.shape, (5, 8)) + self.assertTrue(torch.isfinite(out.quantization_loss).all()) + + def test_decode_codes_shared_base(self) -> None: + codes = torch.randint(0, 16, (5, 3)) + recon = self.rvq.decode_codes(codes) + self.assertEqual(recon.shape, (5, 8)) + + def test_get_codes_no_grad(self) -> None: + codes = self.rvq.get_codes(torch.randn(4, 8)) + self.assertEqual(codes.shape, (4, 3)) + + +class ResidualKMeansQuantizerTest(unittest.TestCase): + def test_is_subclass(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertIsInstance(rkq, ResidualQuantizer) + + def test_non_uniform_codebook_rejected(self) -> None: + with self.assertRaises(AssertionError): + ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=[8, 4]) + + def test_forward_returns_zeros_before_fit(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertFalse(rkq.all_initialized) + codes, quantized = rkq(torch.randn(5, 4)) + self.assertEqual(codes.shape, (5, 2)) + self.assertEqual(quantized.shape, (5, 4)) + + def test_train_offline_then_decode(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + self.assertTrue(rkq.all_initialized) + + codes, _ = rkq(torch.randn(5, 4)) + self.assertTrue((codes >= 0).all() and (codes < 8).all()) + recon = rkq.decode_codes(codes) # inherited from the base + self.assertEqual(recon.shape, (5, 4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid_generation/vector_quantize_test.py b/tzrec/modules/sid_generation/vector_quantize_test.py new file mode 100644 index 000000000..833f36231 --- /dev/null +++ b/tzrec/modules/sid_generation/vector_quantize_test.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from tzrec.modules.sid_generation.types import QuantizeForwardMode +from tzrec.modules.sid_generation.vector_quantize import VectorQuantize + + +class VectorQuantizeTest(unittest.TestCase): + """Tests for a single VectorQuantize layer.""" + + @parameterized.expand( + [ + ("ste_l2", QuantizeForwardMode.STE, "l2", True), + ("ste_cosine", QuantizeForwardMode.STE, "cosine", True), + ("ste_no_sinkhorn", QuantizeForwardMode.STE, "l2", False), + ("gumbel_l2", QuantizeForwardMode.GUMBEL_SOFTMAX, "l2", True), + ] + ) + def test_train_forward(self, _name, mode, distance_type, use_sinkhorn) -> None: + torch.manual_seed(0) + vq = VectorQuantize( + embed_dim=8, + n_embed=16, + forward_mode=mode, + distance_type=distance_type, + use_sinkhorn=use_sinkhorn, + ) + vq.train() + x = torch.randn(5, 8, requires_grad=True) + out = vq(x) + self.assertEqual(out.embeddings.shape, (5, 8)) + self.assertEqual(out.ids.shape, (5,)) + self.assertTrue((out.ids >= 0).all() and (out.ids < 16).all()) + self.assertTrue(torch.isfinite(out.embeddings).all()) + + def test_train_forward_backward_reaches_input(self) -> None: + torch.manual_seed(0) + vq = VectorQuantize(embed_dim=8, n_embed=16, use_sinkhorn=False) + vq.train() + x = torch.randn(5, 8, requires_grad=True) + out = vq(x) + out.embeddings.sum().backward() + # STE routes gradient back through x. + self.assertIsNotNone(x.grad) + self.assertTrue(torch.isfinite(x.grad).all()) + + def test_eval_forward_is_plain_lookup(self) -> None: + torch.manual_seed(0) + vq = VectorQuantize(embed_dim=4, n_embed=8) + vq.eval() + x = torch.randn(3, 4) + out = vq(x) + # In eval, emb == embedding(ids) exactly. + torch.testing.assert_close(out.embeddings, vq.embedding(out.ids)) + + +if __name__ == "__main__": + unittest.main() From 78a3ce2fb314ab4003caa8a18c071ec33c90701f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:29:06 +0000 Subject: [PATCH 010/129] [refactor] SID clip_loss: use built-in differentiable all_gather + dist test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (clip_loss.py:50 — "Why not directly use torch.distributed.nn.functional.all_gather?"). Replace the hand-rolled GatherLayer autograd.Function with torch.distributed.nn.functional .all_gather inside _all_gather_with_grad; its backward already sum-reduces the per-rank grads and returns this rank's slice, so the custom Function (and its GatherLayer export) are gone. Adds clip_loss_dist_test.py: a 2-rank multi-process test (NCCL on GPU when >=2 devices, else gloo/CPU) asserting all_gather forward values, the world_size-summed backward, and a MaskedCLIPLoss forward/backward across ranks. Validated on 2x GPU (NCCL). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid_generation/__init__.py | 4 +- tzrec/modules/sid_generation/clip_loss.py | 41 ++---- .../sid_generation/clip_loss_dist_test.py | 131 ++++++++++++++++++ 3 files changed, 141 insertions(+), 35 deletions(-) create mode 100644 tzrec/modules/sid_generation/clip_loss_dist_test.py diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py index e466f057e..916f1f94c 100644 --- a/tzrec/modules/sid_generation/__init__.py +++ b/tzrec/modules/sid_generation/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from tzrec.modules.sid_generation.clip_loss import ( - GatherLayer, + MaskedCLIPLoss, ) from tzrec.modules.sid_generation.kmeans import ( KMeansLayer, @@ -38,7 +38,7 @@ "QuantizeOutput", "ResidualQuantizerOutput", "VectorQuantize", - "GatherLayer", + "MaskedCLIPLoss", "ResidualQuantizer", "ResidualVectorQuantizer", "KMeansLayer", diff --git a/tzrec/modules/sid_generation/clip_loss.py b/tzrec/modules/sid_generation/clip_loss.py index 701576cd4..16fd93908 100644 --- a/tzrec/modules/sid_generation/clip_loss.py +++ b/tzrec/modules/sid_generation/clip_loss.py @@ -11,50 +11,25 @@ """CLIP contrastive learning loss with distributed all-gather support.""" -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import torch import torch.distributed as dist +import torch.distributed.nn as dist_nn from torch import nn from torch.nn import functional as F -class GatherLayer(torch.autograd.Function): - """Gather tensors from all workers with gradient support. - - Standard ``dist.all_gather`` detaches gradients; this custom - ``autograd.Function`` keeps the computation graph connected so - that contrastive losses can backpropagate through gathered tensors. - """ - - @staticmethod - def forward(ctx, x: torch.Tensor) -> Tuple[torch.Tensor, ...]: - """All-gather ``x`` across ranks, returning one tensor per rank.""" - output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] - dist.all_gather(output, x) - return tuple(output) - - @staticmethod - def backward(ctx, *grads: torch.Tensor) -> torch.Tensor: - """Sum-reduce the per-rank grads and return this rank's slice. - - ``all_reduce`` is sum, so reducing only this rank's slice gives - the same result as stacking + reducing + slicing, but avoids - materialising the full ``(world_size, B, D)`` buffer. - """ - grad_local = grads[dist.get_rank()].contiguous() - dist.all_reduce(grad_local) - return grad_local - - def _all_gather_with_grad( tensors: List[torch.Tensor], ) -> List[torch.Tensor]: """All-gather tensors across distributed workers with gradient support. - In single-process mode, returns input tensors unchanged. - In multi-process mode, uses GatherLayer for backward-compatible - all_gather. + In single-process mode, returns input tensors unchanged. In + multi-process mode, uses ``torch.distributed.nn.functional.all_gather`` + — the built-in differentiable collective (its backward sum-reduces the + per-rank grads and returns this rank's slice), so no custom + ``autograd.Function`` is needed. Args: tensors (List[Tensor]): list of tensors to gather. @@ -67,7 +42,7 @@ def _all_gather_with_grad( gathered: List[torch.Tensor] = [] for tensor in tensors: - tensor_all = GatherLayer.apply(tensor) + tensor_all = dist_nn.all_gather(tensor) # differentiable, one per rank gathered.append(torch.cat(tensor_all, dim=0)) return gathered diff --git a/tzrec/modules/sid_generation/clip_loss_dist_test.py b/tzrec/modules/sid_generation/clip_loss_dist_test.py new file mode 100644 index 000000000..178c54dbe --- /dev/null +++ b/tzrec/modules/sid_generation/clip_loss_dist_test.py @@ -0,0 +1,131 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-process tests for the CLIP distributed all-gather path. + +Validates ``_all_gather_with_grad`` (built on the differentiable +``torch.distributed.nn.functional.all_gather``) and ``MaskedCLIPLoss`` +across ranks. Uses NCCL on GPU when >=2 devices are available (the +production path the reviewer cared about), else falls back to gloo/CPU, +so the test is runnable on a multi-GPU box and in CPU CI alike. +""" + +import os +import unittest + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from tzrec.modules.sid_generation.clip_loss import ( + MaskedCLIPLoss, + _all_gather_with_grad, +) +from tzrec.utils import misc_util + +WORLD_SIZE = 2 + + +def _init(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _all_gather_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + # Each rank holds a distinct, rank-identifying tensor. + x = torch.full((2, 3), float(rank + 1), device=device, requires_grad=True) + gathered = _all_gather_with_grad([x])[0] + + # Forward: gathered is (world_size*2, 3); rank r contributes rows + # [2r : 2r+2] all equal to (r+1). + assert gathered.shape == (world_size * 2, 3), gathered.shape + for r in range(world_size): + block = gathered[2 * r : 2 * r + 2] + assert torch.allclose(block, torch.full_like(block, float(r + 1))), ( + f"rank{rank}: gathered block {r} wrong: {block}" + ) + + # Backward: identical scalar loss on every rank -> grad to every gathered + # element is 1; the differentiable all_gather sum-reduces across ranks, + # so the local input grad is world_size * ones. + gathered.sum().backward() + assert x.grad is not None, f"rank{rank}: no grad" + assert torch.isfinite(x.grad).all(), f"rank{rank}: non-finite grad" + expected = torch.full_like(x, float(world_size)) + assert torch.allclose(x.grad, expected), f"rank{rank}: grad {x.grad} != {expected}" + dist.destroy_process_group() + + +def _masked_clip_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + torch.manual_seed(1234 + rank) + B, D = 4, 8 + scale = torch.tensor(np.log(1 / 0.07)).exp().to(device) + feats = { + "image_embed": torch.randn(B, D, device=device, requires_grad=True), + "text_embed": torch.randn(B, D, device=device, requires_grad=True), + "image_embed_ori": torch.randn(B, D, device=device), + "text_embed_ori": torch.randn(B, D, device=device), + "logit_scale_self": scale, + "logit_scale_cl": scale, + "logit_scale": scale, + } + mask = torch.ones(B, dtype=torch.bool, device=device) + + loss_fn = MaskedCLIPLoss().to(device) + out = loss_fn(feats, mask) + clip_loss = out["clip_loss"] + assert torch.isfinite(clip_loss).all(), f"rank{rank}: non-finite clip_loss" + assert clip_loss.item() > 0.0, f"rank{rank}: clip_loss not positive" + + clip_loss.backward() + g = feats["image_embed"].grad + assert g is not None and torch.isfinite(g).all(), f"rank{rank}: bad grad" + dist.destroy_process_group() + + +def _run(target) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=target, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + +class ClipLossDistTest(unittest.TestCase): + """2-rank tests for the CLIP distributed collectives.""" + + def test_all_gather_with_grad(self) -> None: + _run(_all_gather_worker) + + def test_masked_clip_loss(self) -> None: + _run(_masked_clip_worker) + + +if __name__ == "__main__": + unittest.main() From f4851a32e48ccfc08a7920a3e8580a1e16f6b635 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:31:47 +0000 Subject: [PATCH 011/129] [test] SID: add multi-rank test for SidRqkmeans.on_train_end DDP path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (sid_rqkmeans_test.py:133 — the single-rank test never enters the DDP branch; the gather_object/broadcast/_is_initialized path was untested). Adds a 2-rank multi-process test (NCCL on GPU when >=2 devices, else gloo) that fills each rank's offline buffer, runs on_train_end, and asserts: every rank ends initialized with non-zero, cross-rank-identical (broadcast) centroids, and eval predict emits valid in-range codes. Empirically refutes the "gather_object is incompatible with NCCL" concern from review for the pinned torch: the path completes on NCCL (validated on 2x GPU). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans_dist_test.py | 139 +++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 tzrec/models/sid_rqkmeans_dist_test.py diff --git a/tzrec/models/sid_rqkmeans_dist_test.py b/tzrec/models/sid_rqkmeans_dist_test.py new file mode 100644 index 000000000..1b511780b --- /dev/null +++ b/tzrec/models/sid_rqkmeans_dist_test.py @@ -0,0 +1,139 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-process tests for SidRqkmeans.on_train_end's DDP code path. + +This exercises the collective sequence the single-process unit test +cannot reach: the cross-rank empty-buffer all_reduce, ``gather_object`` +of the per-rank embedding buffers to rank 0, the FAISS fit, and the +``broadcast`` of centroids + ``_is_initialized`` fill on every rank. + +Uses NCCL on GPU when >=2 devices are available (the production backend +the reviewer flagged for ``gather_object``), else gloo/CPU. +""" + +import os +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torchrec import KeyedTensor + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.protos import model_pb2 +from tzrec.protos.models import sid_model_pb2 +from tzrec.utils import misc_util + +WORLD_SIZE = 2 + + +def _init(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _make_batch(batch_size: int, input_dim: int, device: torch.device) -> Batch: + dense = KeyedTensor.from_tensor_list( + keys=["item_emb"], tensors=[torch.randn(batch_size, input_dim, device=device)] + ) + return Batch( + dense_features={BASE_DATA_GROUP: dense}, sparse_features={}, labels={} + ) + + +def _create_model(input_dim: int, n_layers: int, k: int): + from google.protobuf.struct_pb2 import Struct + + from tzrec.models.sid_rqkmeans import SidRqkmeans + + faiss_kwargs = Struct() + faiss_kwargs.update({"niter": 5, "verbose": False, "seed": 1234}) + cfg = sid_model_pb2.SidRqkmeans( + input_dim=input_dim, + codebook=[k] * n_layers, + normalize_residuals=False, + faiss_kmeans_kwargs=faiss_kwargs, + embedding_feature_name="item_emb", + ) + model_config = model_pb2.ModelConfig(sid_rqkmeans=cfg) + return SidRqkmeans(model_config=model_config, features=[], labels=[]) + + +def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + input_dim, n_layers, k = 16, 2, 16 + model = _create_model(input_dim, n_layers, k).to(device) + model.train() + + torch.manual_seed(100 + rank) + for _ in range(6): + model.predict(_make_batch(32, input_dim, device)) + assert len(model._offline_buffer) == 6, f"rank{rank}: buffer not filled" + + # The collective sequence under test: empty-flag all_reduce -> + # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. + model.on_train_end() + + # Every rank must end initialized with non-zero centroids. + for layer in model._quantizer.layers: + assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" + assert layer.centroids.abs().sum().item() > 0.0, f"rank{rank}: zero centroids" + + # Centroids were broadcast from rank0 -> must be bit-identical across + # ranks (min == max under all_reduce). + for layer in model._quantizer.layers: + cmin = layer.centroids.clone() + cmax = layer.centroids.clone() + dist.all_reduce(cmin, op=dist.ReduceOp.MIN) + dist.all_reduce(cmax, op=dist.ReduceOp.MAX) + assert torch.allclose(cmin, cmax), f"rank{rank}: centroids differ across ranks" + + # After the fit, eval predict emits valid codes. + model.eval() + codes = model.predict(_make_batch(8, input_dim, device))["codes"] + assert codes.shape == (8, n_layers), f"rank{rank}: bad codes shape {codes.shape}" + assert (codes >= 0).all() and (codes < k).all(), f"rank{rank}: codes out of range" + dist.destroy_process_group() + + +def _run(target) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=target, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + +class SidRqkmeansDistTest(unittest.TestCase): + """2-rank test for SidRqkmeans.on_train_end.""" + + def test_on_train_end_ddp(self) -> None: + _run(_on_train_end_worker) + + +if __name__ == "__main__": + unittest.main() From 26bcecd0511156ba8b1ae3020d5825e5590da7d0 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:43:21 +0000 Subject: [PATCH 012/129] [refactor] SidRqkmeans.on_train_end: drop redundant empty-buffer handshake MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (tiankongdeguiji, sid_rqkmeans.py:250 — "In synchronized training, some ranks shouldn't have empty data—this logic is redundant"). The dataset layer already guarantees it: file-based datasets enforce `num_files >= world_size` (tzrec/datasets/dataset.py raises otherwise), so in synchronized DDP training every rank receives at least one shard and reaches the gather with a non-empty buffer. The cross-rank all_reduce(MAX) empty-flag handshake was therefore dead insurance. Removed it: the DDP branch now goes straight to gather_object/fit/ broadcast. The single-process branch keeps a plain local empty-buffer no-op guard (not a collective) so on_train_end without a training pass still degrades gracefully. Verified: single-process unit tests (incl. empty-buffer no-op) and the 2-rank NCCL on_train_end test pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 101 +++++++++++++++-------------------- 1 file changed, 44 insertions(+), 57 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 6199e6b24..a7be1f9a4 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -212,45 +212,23 @@ def on_train_end(self) -> None: """Trigger one-shot FAISS fit after the train_eval loop ends. Overrides :meth:`BaseModel.on_train_end`. Called unconditionally - by ``tzrec.main.train_and_evaluate`` after the training loop - exits. No-op when the buffer is empty. + by ``tzrec.main.train_and_evaluate`` after the training loop exits. DDP behavior: - rank0: receive local buffers via gather_object, concat, run FAISS fit, then broadcast centroids to other ranks. - other ranks: ship local buffer via gather_object(dst=0) and wait for the broadcast. + + No cross-rank empty-buffer handshake is needed: the dataset layer + enforces ``num_files >= world_size`` (``tzrec.datasets.dataset`` + raises otherwise), so in synchronized training every rank receives + at least one shard and reaches the gather with a non-empty buffer. """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 ) - # A local-only empty check would deadlock: the empty rank returns - # while peers block in gather_object below. OR the flag across - # ranks and bail together if any rank is empty. - local_empty = len(self._offline_buffer) == 0 - if is_ddp: - # int32, not bool — NCCL bool support is version-dependent. - flag = torch.tensor( - int(local_empty), - dtype=torch.int32, - device=self._dummy_param.device, - ) - dist.all_reduce(flag, op=dist.ReduceOp.MAX) - any_empty = bool(flag.item()) - else: - any_empty = local_empty - - if any_empty: - if (not is_ddp) or dist.get_rank() == 0: - logger.warning( - "[SidRqkmeans.on_train_end] at least one rank has an " - "empty offline buffer; skipping FAISS fit on all ranks. " - "Did the train_eval loop run, and is the per-rank shard " - "non-empty?" - ) - return - if is_ddp: # DDP path: every rank ships its local buffer to rank 0 via # gather_object (variable-length pickle — fine for this one- @@ -286,34 +264,43 @@ def on_train_end(self) -> None: dist.broadcast(layer.centroids, src=0) layer._is_initialized.fill_(True) dist.barrier() - else: - # Single-process path: build the full numpy matrix directly - # from the buffer list, popping each chunk after copy so the - # transient memory high-water mark stays ~= final matrix size - # (instead of 2× when going through torch.cat). - N = sum(t.shape[0] for t in self._offline_buffer) - D = self._offline_buffer[0].shape[1] - logger.info( - "[SidRqkmeans.on_train_end] fitting FAISS on " - "%d samples (D=%d)." % (N, D) + return + + # Single-process path. Guard an empty buffer with a plain local + # check (no collective): on_train_end may be invoked without a + # training pass having run. + if len(self._offline_buffer) == 0: + logger.warning( + "[SidRqkmeans.on_train_end] empty offline buffer; skipping " + "FAISS fit. Did the train_eval loop run?" ) - full_np = np.empty((N, D), dtype=np.float32) - offset = 0 - # Pop from the front; each popped tensor is released before - # the next copy so cumulative torch memory shrinks monotonically. - while self._offline_buffer: - t = self._offline_buffer.pop(0) - n = t.shape[0] - # .float().numpy() returns a view sharing storage with - # the fp32 tensor; the subsequent assignment copies into - # full_np, after which ``t`` can be freed. - full_np[offset : offset + n] = t.float().numpy() - offset += n - del t - del self._offline_buffer - self._offline_buffer = [] + return - # train_offline takes ownership of ``full_np`` (in-place - # residual updates); drop our reference after the call. - self._quantizer.train_offline(full_np, verbose=True) - del full_np + # Build the full numpy matrix directly from the buffer list, popping + # each chunk after copy so the transient memory high-water mark stays + # ~= final matrix size (instead of 2× when going through torch.cat). + N = sum(t.shape[0] for t in self._offline_buffer) + D = self._offline_buffer[0].shape[1] + logger.info( + "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." % (N, D) + ) + full_np = np.empty((N, D), dtype=np.float32) + offset = 0 + # Pop from the front; each popped tensor is released before the next + # copy so cumulative torch memory shrinks monotonically. + while self._offline_buffer: + t = self._offline_buffer.pop(0) + n = t.shape[0] + # .float().numpy() returns a view sharing storage with the fp32 + # tensor; the subsequent assignment copies into full_np, after + # which ``t`` can be freed. + full_np[offset : offset + n] = t.float().numpy() + offset += n + del t + del self._offline_buffer + self._offline_buffer = [] + + # train_offline takes ownership of ``full_np`` (in-place residual + # updates); drop our reference after the call. + self._quantizer.train_offline(full_np, verbose=True) + del full_np From e75800691a242e087da6086ded4052bb100dec01 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:49:03 +0000 Subject: [PATCH 013/129] [refactor] SID: move CLIP loss into tzrec/loss MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (clip_loss.py:75 — "refactor into tzrec/loss"). MaskedCLIPLoss is a generic contrastive loss, not a SID quantization primitive, and has no sid_generation dependencies — so it belongs with the other loss modules (focal_loss, jrc_loss, ...). - tzrec/modules/sid_generation/clip_loss.py -> tzrec/loss/clip_loss.py (+ colocated clip_loss_test.py / clip_loss_dist_test.py). - SidRqvae imports it from tzrec.loss.clip_loss; dropped the MaskedCLIPLoss re-export from sid_generation/__init__.py. Behavior unchanged; single-process + 2-rank NCCL clip tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/{modules/sid_generation => loss}/clip_loss.py | 0 tzrec/{modules/sid_generation => loss}/clip_loss_dist_test.py | 2 +- tzrec/{modules/sid_generation => loss}/clip_loss_test.py | 2 +- tzrec/models/sid_rqvae.py | 2 +- tzrec/modules/sid_generation/__init__.py | 4 ---- 5 files changed, 3 insertions(+), 7 deletions(-) rename tzrec/{modules/sid_generation => loss}/clip_loss.py (100%) rename tzrec/{modules/sid_generation => loss}/clip_loss_dist_test.py (98%) rename tzrec/{modules/sid_generation => loss}/clip_loss_test.py (98%) diff --git a/tzrec/modules/sid_generation/clip_loss.py b/tzrec/loss/clip_loss.py similarity index 100% rename from tzrec/modules/sid_generation/clip_loss.py rename to tzrec/loss/clip_loss.py diff --git a/tzrec/modules/sid_generation/clip_loss_dist_test.py b/tzrec/loss/clip_loss_dist_test.py similarity index 98% rename from tzrec/modules/sid_generation/clip_loss_dist_test.py rename to tzrec/loss/clip_loss_dist_test.py index 178c54dbe..d0824e1ee 100644 --- a/tzrec/modules/sid_generation/clip_loss_dist_test.py +++ b/tzrec/loss/clip_loss_dist_test.py @@ -26,7 +26,7 @@ import torch.distributed as dist import torch.multiprocessing as mp -from tzrec.modules.sid_generation.clip_loss import ( +from tzrec.loss.clip_loss import ( MaskedCLIPLoss, _all_gather_with_grad, ) diff --git a/tzrec/modules/sid_generation/clip_loss_test.py b/tzrec/loss/clip_loss_test.py similarity index 98% rename from tzrec/modules/sid_generation/clip_loss_test.py rename to tzrec/loss/clip_loss_test.py index 227f8afaa..703b3c00c 100644 --- a/tzrec/modules/sid_generation/clip_loss_test.py +++ b/tzrec/loss/clip_loss_test.py @@ -14,7 +14,7 @@ import numpy as np import torch -from tzrec.modules.sid_generation.clip_loss import ( +from tzrec.loss.clip_loss import ( MaskedCLIPLoss, _all_gather_with_grad, ) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 8ca2a1444..1478bd268 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -27,8 +27,8 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature +from tzrec.loss.clip_loss import MaskedCLIPLoss from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid_generation.clip_loss import MaskedCLIPLoss from tzrec.modules.sid_generation.residual_vector_quantizer import ( ResidualVectorQuantizer, ) diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py index 916f1f94c..d6c3e9350 100644 --- a/tzrec/modules/sid_generation/__init__.py +++ b/tzrec/modules/sid_generation/__init__.py @@ -9,9 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tzrec.modules.sid_generation.clip_loss import ( - MaskedCLIPLoss, -) from tzrec.modules.sid_generation.kmeans import ( KMeansLayer, ) @@ -38,7 +35,6 @@ "QuantizeOutput", "ResidualQuantizerOutput", "VectorQuantize", - "MaskedCLIPLoss", "ResidualQuantizer", "ResidualVectorQuantizer", "KMeansLayer", From d2a0032ed637e02ac6bfad1b4227b4712a5ecdf9 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 03:53:58 +0000 Subject: [PATCH 014/129] [refactor] clip_loss: MaskedCLIPLoss subclasses _Loss; fold in all_gather Per review follow-up: - MaskedCLIPLoss now subclasses torch.nn.modules.loss._Loss (matching the tzrec/loss convention, e.g. BinaryFocalLoss) instead of bare nn.Module. - The module-level _all_gather_with_grad helper had MaskedCLIPLoss as its only (production) caller, so it becomes a private @staticmethod MaskedCLIPLoss._all_gather_with_grad alongside _gather_bool_mask. Tests updated to the static-method form; single-process + 2-rank NCCL clip tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/clip_loss.py | 59 +++++++++++++++---------------- tzrec/loss/clip_loss_dist_test.py | 7 ++-- tzrec/loss/clip_loss_test.py | 7 ++-- 3 files changed, 32 insertions(+), 41 deletions(-) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py index 16fd93908..f6d1097d5 100644 --- a/tzrec/loss/clip_loss.py +++ b/tzrec/loss/clip_loss.py @@ -16,38 +16,11 @@ import torch import torch.distributed as dist import torch.distributed.nn as dist_nn -from torch import nn from torch.nn import functional as F +from torch.nn.modules.loss import _Loss -def _all_gather_with_grad( - tensors: List[torch.Tensor], -) -> List[torch.Tensor]: - """All-gather tensors across distributed workers with gradient support. - - In single-process mode, returns input tensors unchanged. In - multi-process mode, uses ``torch.distributed.nn.functional.all_gather`` - — the built-in differentiable collective (its backward sum-reduces the - per-rank grads and returns this rank's slice), so no custom - ``autograd.Function`` is needed. - - Args: - tensors (List[Tensor]): list of tensors to gather. - - Returns: - List[Tensor]: gathered tensors, each (world_size * B, ...). - """ - if not dist.is_initialized() or dist.get_world_size() == 1: - return tensors - - gathered: List[torch.Tensor] = [] - for tensor in tensors: - tensor_all = dist_nn.all_gather(tensor) # differentiable, one per rank - gathered.append(torch.cat(tensor_all, dim=0)) - return gathered - - -class MaskedCLIPLoss(nn.Module): +class MaskedCLIPLoss(_Loss): """Masked CLIP loss for mixed recon+clip batches. In a mixed batch, recon rows (clip_mask=False) should not @@ -79,6 +52,30 @@ def __init__(self) -> None: self.last_local_batch_size: Optional[int] = None self._rank = dist.get_rank() if dist.is_initialized() else 0 + @staticmethod + def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """All-gather tensors across workers with gradient support. + + In single-process mode, returns the input tensors unchanged. In + multi-process mode, uses ``torch.distributed.nn.functional + .all_gather`` — the built-in differentiable collective (its backward + sum-reduces the per-rank grads and returns this rank's slice), so no + custom ``autograd.Function`` is needed. + + Args: + tensors (List[Tensor]): list of tensors to gather. + + Returns: + List[Tensor]: gathered tensors, each (world_size * B, ...). + """ + if not dist.is_initialized() or dist.get_world_size() == 1: + return tensors + gathered: List[torch.Tensor] = [] + for tensor in tensors: + tensor_all = dist_nn.all_gather(tensor) # differentiable, per rank + gathered.append(torch.cat(tensor_all, dim=0)) + return gathered + @staticmethod def _gather_bool_mask(mask: torch.Tensor) -> torch.Tensor: """All-gather bool mask across distributed workers.""" @@ -145,10 +142,10 @@ def forward( text_embed = F.normalize(text_embed, dim=-1, p=2) # All-gather across GPUs (with gradient support) - image_embed_all, text_embed_all = _all_gather_with_grad( + image_embed_all, text_embed_all = self._all_gather_with_grad( [image_embed, text_embed] ) - image_embed_all_ori, text_embed_all_ori = _all_gather_with_grad( + image_embed_all_ori, text_embed_all_ori = self._all_gather_with_grad( [image_embed_ori, text_embed_ori] ) diff --git a/tzrec/loss/clip_loss_dist_test.py b/tzrec/loss/clip_loss_dist_test.py index d0824e1ee..80d80cf0a 100644 --- a/tzrec/loss/clip_loss_dist_test.py +++ b/tzrec/loss/clip_loss_dist_test.py @@ -26,10 +26,7 @@ import torch.distributed as dist import torch.multiprocessing as mp -from tzrec.loss.clip_loss import ( - MaskedCLIPLoss, - _all_gather_with_grad, -) +from tzrec.loss.clip_loss import MaskedCLIPLoss from tzrec.utils import misc_util WORLD_SIZE = 2 @@ -53,7 +50,7 @@ def _all_gather_worker(rank: int, world_size: int, port: int) -> None: device = _init(rank, world_size, port) # Each rank holds a distinct, rank-identifying tensor. x = torch.full((2, 3), float(rank + 1), device=device, requires_grad=True) - gathered = _all_gather_with_grad([x])[0] + gathered = MaskedCLIPLoss._all_gather_with_grad([x])[0] # Forward: gathered is (world_size*2, 3); rank r contributes rows # [2r : 2r+2] all equal to (r+1). diff --git a/tzrec/loss/clip_loss_test.py b/tzrec/loss/clip_loss_test.py index 703b3c00c..f124c2ff8 100644 --- a/tzrec/loss/clip_loss_test.py +++ b/tzrec/loss/clip_loss_test.py @@ -14,16 +14,13 @@ import numpy as np import torch -from tzrec.loss.clip_loss import ( - MaskedCLIPLoss, - _all_gather_with_grad, -) +from tzrec.loss.clip_loss import MaskedCLIPLoss class AllGatherWithGradTest(unittest.TestCase): def test_single_process_identity(self) -> None: a, b = torch.randn(3, 4), torch.randn(3, 4) - out = _all_gather_with_grad([a, b]) + out = MaskedCLIPLoss._all_gather_with_grad([a, b]) self.assertIs(out[0], a) self.assertIs(out[1], b) From 0a9ca3b1f5b81596ddd279b798e9dd9b37ce2599 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 05:33:21 +0000 Subject: [PATCH 015/129] [refactor] SidRqkmeans: use config_util.config_to_kwargs for faiss kwargs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (tiankongdeguiji, sid_rqkmeans.py:88 — "use config_to_kwargs"). Replace the bespoke MessageToDict call with the project-standard helper that ~35 other models already use (rank_model, match_model, dlrm, ...). config_to_kwargs returns Struct numbers as floats, so _coerce_proto_numbers is kept to restore the ints faiss.Kmeans expects (niter/seed/nredo). Note: config_to_kwargs passes MessageToDict(..., including_default_value_ fields=...), which protobuf 5.x renamed/removed — so it (like every other config_to_kwargs caller in tzrec) requires protobuf 4.x. Validated on the supported env (protobuf 4.25.9): the SidRqkmeans suite passes. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index a7be1f9a4..2e1e28760 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -23,7 +23,6 @@ import torch import torch.distributed as dist import torchmetrics -from google.protobuf.json_format import MessageToDict from torch import nn from tzrec.datasets.utils import Batch @@ -32,6 +31,7 @@ from tzrec.modules.sid_generation import ResidualKMeansQuantizer from tzrec.modules.sid_generation.kmeans import recon_diagnostics from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils import config_util from tzrec.utils.logging_util import logger @@ -78,13 +78,11 @@ def __init__( cfg = self._model_config # SidRqkmeans proto message - # NOTE: the project helper ``config_util.config_to_kwargs`` would be - # the idiomatic choice here, but it passes ``MessageToDict(..., - # including_default_value_fields=True)`` which protobuf 5.x removed, - # so it raises framework-wide under the installed protobuf. Use a - # direct (version-safe) MessageToDict until that helper is fixed. + # config_to_kwargs returns Struct numbers as floats (it is + # MessageToDict under the hood), so _coerce_proto_numbers restores + # the ints faiss.Kmeans expects (niter, seed, nredo, ...). self._faiss_kwargs = ( - _coerce_proto_numbers(MessageToDict(cfg.faiss_kmeans_kwargs)) + _coerce_proto_numbers(config_util.config_to_kwargs(cfg.faiss_kmeans_kwargs)) if cfg.HasField("faiss_kmeans_kwargs") else {} ) From a0daba8a677ef2ed785ca14397e5034430d5c9ca Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 06:03:28 +0000 Subject: [PATCH 016/129] [perf] SidRqkmeans: self-tuning reservoir buffer (bounded host memory) Replaces the unbounded per-step offline buffer (every embedding .cpu()'d and kept) with Vitter Algorithm-R reservoir sampling into a fixed-size host buffer, fixing the rank-0 OOM risk on large corpora. Self-tuning cap: FAISS K-Means only ever consumes K*max_points_per_centroid points (it subsamples internally), so that is the target. New proto field train_sample_size (0 = auto) sets the global target; the per-rank cap is target/world_size so the gathered set on rank0 is ~target and FAISS does no further subsampling. With the default max_points_per_centroid=256 and K=256 that's ~65K rows/layer instead of the whole corpus. on_train_end now consumes the reservoir sample directly (gather -> fit -> broadcast) and releases it. Tests updated to the reservoir state; added test_reservoir_caps_memory. Algorithm verified uniform + capped in isolation; model paths validated on the remote (protobuf 4.x) below. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 160 +++++++++++++++---------- tzrec/models/sid_rqkmeans_dist_test.py | 2 +- tzrec/models/sid_rqkmeans_test.py | 33 +++-- tzrec/protos/models/sid_model.proto | 6 + 4 files changed, 127 insertions(+), 74 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 2e1e28760..dd67c0bbf 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -19,7 +19,6 @@ from typing import Any, Dict, List, Optional -import numpy as np import torch import torch.distributed as dist import torchmetrics @@ -87,6 +86,7 @@ def __init__( else {} ) + self._input_dim = cfg.input_dim self._quantizer = ResidualKMeansQuantizer( embed_dim=cfg.input_dim, n_layers=self._n_layers, @@ -95,14 +95,81 @@ def __init__( faiss_kmeans_kwargs=self._faiss_kwargs, ) - # CPU buffer for embeddings collected during training; FAISS - # consumes it in on_train_end() at end-of-loop. - self._offline_buffer: List[torch.Tensor] = [] + # Per-rank reservoir cap. FAISS K-Means only ever consumes + # K * max_points_per_centroid points (it subsamples internally), so + # buffering the full corpus is wasted memory. We reservoir-sample to + # that target instead, split across ranks so the gathered set on + # rank0 is ~train_sample_size and FAISS does no further subsampling. + k = self._n_embed_list[0] + max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) + global_target = ( + cfg.train_sample_size if cfg.train_sample_size > 0 else k * max_ppc + ) + world_size = dist.get_world_size() if dist.is_initialized() else 1 + self._sample_cap = max(1, -(-global_target // world_size)) # ceil div + + # Bounded host-resident reservoir (allocated lazily on first batch, + # once the embedding dim/device is known). ``_n_filled`` slots hold + # data; ``_n_seen`` is the running count for the sampling probability. + self._reservoir: Optional[torch.Tensor] = None + self._n_filled = 0 + self._n_seen = 0 # KMeans has no learnable parameters (centroids use register_buffer). # Add dummy param to keep optimizer/DDP happy. self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) + @torch.no_grad() + def _reservoir_add(self, x: torch.Tensor) -> None: + """Add a batch to the bounded reservoir (Vitter's Algorithm R). + + Keeps a uniform random ``self._sample_cap`` subset of every embedding + seen so far in O(cap) host memory, in a single streaming pass. + + Args: + x (Tensor): a batch of embeddings, shape (B, D); copied to host. + """ + x = x.detach().to("cpu", dtype=torch.float32) + cap = self._sample_cap + if self._reservoir is None: + self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) + + # Phase 1: fill empty slots first. + if self._n_filled < cap: + take = min(x.shape[0], cap - self._n_filled) + self._reservoir[self._n_filled : self._n_filled + take] = x[:take] + self._n_filled += take + self._n_seen += take + x = x[take:] + if x.shape[0] == 0: + return + + # Phase 2: replacement. Row j (0-indexed in x) is the + # (n_seen + j)-th item seen; it enters the reservoir with prob + # cap / (n_seen + j + 1), displacing a uniformly-random slot. + r = x.shape[0] + pos = self._n_seen + torch.arange(r) + accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) + idx = accept.nonzero(as_tuple=True)[0] + if idx.numel() > 0: + slots = torch.randint(0, cap, (idx.numel(),)) + # Intra-batch slot collisions resolve last-write-wins; the bias is + # O(B/cap) per step and negligible for codebook fitting. + self._reservoir[slots] = x[idx] + self._n_seen += r + + def _reservoir_sample(self) -> torch.Tensor: + """Return the filled portion of the reservoir, shape (n_filled, D).""" + if self._reservoir is None or self._n_filled == 0: + return torch.empty(0, self._input_dim, dtype=torch.float32) + return self._reservoir[: self._n_filled] + + def _reset_reservoir(self) -> None: + """Drop the reservoir after the FAISS fit to free host memory.""" + self._reservoir = None + self._n_filled = 0 + self._n_seen = 0 + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -117,21 +184,12 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """ embedding = self._extract_feature(batch) - # Training: buffer for the end-of-loop FAISS fit and return dummy - # codes — the codebook does not exist yet. - # We move to host (.cpu()) deliberately: the whole corpus is - # accumulated before the single FAISS pass, so keeping every step's - # batch resident in GPU memory would OOM, and the common faiss-cpu - # build cannot consume CUDA tensors anyway. (A faiss-gpu fit could - # take a GPU tensor, but that is the exception, not the default.) - # TODO(perf): .cpu() is a synchronous D2H per step and the buffer - # grows unbounded with steps. Rework to either (a) GPU-resident - # buffer + bulk D2H in on_train_end with size cap, or (b) replace - # the train pass with an inference_mode corpus walk launched from - # on_train_end. Skipped here to avoid OOM-vs-refactor tradeoffs; - # tracked separately. + # Training: reservoir-sample into a bounded host buffer for the + # end-of-loop FAISS fit, and return dummy codes — the codebook does + # not exist yet. The reservoir caps memory at _sample_cap rows + # regardless of corpus size (FAISS only consumes a subset anyway). if self.is_train: - self._offline_buffer.append(embedding.detach().cpu()) + self._reservoir_add(embedding) B = embedding.shape[0] return { "codes": torch.zeros( @@ -213,30 +271,28 @@ def on_train_end(self) -> None: by ``tzrec.main.train_and_evaluate`` after the training loop exits. DDP behavior: - - rank0: receive local buffers via gather_object, concat, - run FAISS fit, then broadcast centroids to other ranks. - - other ranks: ship local buffer via gather_object(dst=0) - and wait for the broadcast. + - rank0: receive each rank's reservoir sample via gather_object, + concat, run FAISS fit, then broadcast centroids to all ranks. + - other ranks: ship their reservoir sample via gather_object + (dst=0) and wait for the broadcast. No cross-rank empty-buffer handshake is needed: the dataset layer enforces ``num_files >= world_size`` (``tzrec.datasets.dataset`` raises otherwise), so in synchronized training every rank receives - at least one shard and reaches the gather with a non-empty buffer. + at least one shard and reaches the gather with a non-empty sample. """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 ) - if is_ddp: - # DDP path: every rank ships its local buffer to rank 0 via - # gather_object (variable-length pickle — fine for this one- - # shot, CPU-resident gather). Only rank 0 holds the corpus, - # so peak memory is O(world_size) on rank 0 and O(1) elsewhere - # (vs O(world_size²) for all_gather_object). - local = torch.cat(self._offline_buffer, dim=0) - del self._offline_buffer - self._offline_buffer = [] + local = self._reservoir_sample() + self._reset_reservoir() + if is_ddp: + # DDP path: every rank ships its reservoir sample to rank 0 via + # gather_object. Each sample is bounded by _sample_cap, so the + # gathered set on rank0 is ~train_sample_size and FAISS does no + # further subsampling. rank = dist.get_rank() gathered: Optional[List[Optional[torch.Tensor]]] = ( [None] * dist.get_world_size() if rank == 0 else None @@ -264,41 +320,17 @@ def on_train_end(self) -> None: dist.barrier() return - # Single-process path. Guard an empty buffer with a plain local - # check (no collective): on_train_end may be invoked without a - # training pass having run. - if len(self._offline_buffer) == 0: + # Single-process path. Guard an empty sample with a plain local check + # (no collective): on_train_end may be invoked without a training pass. + if local.shape[0] == 0: logger.warning( - "[SidRqkmeans.on_train_end] empty offline buffer; skipping " - "FAISS fit. Did the train_eval loop run?" + "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " + "fit. Did the train_eval loop run?" ) return - # Build the full numpy matrix directly from the buffer list, popping - # each chunk after copy so the transient memory high-water mark stays - # ~= final matrix size (instead of 2× when going through torch.cat). - N = sum(t.shape[0] for t in self._offline_buffer) - D = self._offline_buffer[0].shape[1] logger.info( - "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." % (N, D) + "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." + % (local.shape[0], local.shape[1]) ) - full_np = np.empty((N, D), dtype=np.float32) - offset = 0 - # Pop from the front; each popped tensor is released before the next - # copy so cumulative torch memory shrinks monotonically. - while self._offline_buffer: - t = self._offline_buffer.pop(0) - n = t.shape[0] - # .float().numpy() returns a view sharing storage with the fp32 - # tensor; the subsequent assignment copies into full_np, after - # which ``t`` can be freed. - full_np[offset : offset + n] = t.float().numpy() - offset += n - del t - del self._offline_buffer - self._offline_buffer = [] - - # train_offline takes ownership of ``full_np`` (in-place residual - # updates); drop our reference after the call. - self._quantizer.train_offline(full_np, verbose=True) - del full_np + self._quantizer.train_offline(local, verbose=True) diff --git a/tzrec/models/sid_rqkmeans_dist_test.py b/tzrec/models/sid_rqkmeans_dist_test.py index 1b511780b..82a96f2ac 100644 --- a/tzrec/models/sid_rqkmeans_dist_test.py +++ b/tzrec/models/sid_rqkmeans_dist_test.py @@ -86,7 +86,7 @@ def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: torch.manual_seed(100 + rank) for _ in range(6): model.predict(_make_batch(32, input_dim, device)) - assert len(model._offline_buffer) == 6, f"rank{rank}: buffer not filled" + assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" # The collective sequence under test: empty-flag all_reduce -> # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index b9442476c..03cd2c920 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -68,10 +68,11 @@ def test_proto_parse(self) -> None: self.assertEqual(model._faiss_kwargs.get("niter"), 5) self.assertEqual(model._faiss_kwargs.get("seed"), 1234) self.assertFalse(model._faiss_kwargs.get("verbose")) - self.assertEqual(model._offline_buffer, []) + self.assertEqual(model._n_seen, 0) + self.assertIsNone(model._reservoir) def test_predict_collects_buffer(self) -> None: - """In train mode, predict should append to buffer; never fit.""" + """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 model = self._create_model(input_dim=input_dim) model.train() @@ -81,14 +82,27 @@ def test_predict_collects_buffer(self) -> None: preds = model.predict(batch) self.assertIn("codes", preds) - # Buffer accumulates 4 batches of B samples each - self.assertEqual(len(model._offline_buffer), 4) - total = sum(t.shape[0] for t in model._offline_buffer) - self.assertEqual(total, 4 * B) + # Reservoir holds all 4*B samples (well under the cap) and tracks + # the running count. + self.assertEqual(model._n_seen, 4 * B) + self.assertEqual(model._n_filled, 4 * B) # FAISS not yet triggered: layers should be uninitialized for layer in model._quantizer.layers: self.assertFalse(layer.is_initialized) + def test_reservoir_caps_memory(self) -> None: + """Reservoir bounds the buffer at _sample_cap regardless of corpus.""" + B, input_dim = 16, 8 + model = self._create_model(input_dim=input_dim) + model._sample_cap = 10 # force a tiny cap + model._reset_reservoir() + model.train() + for _ in range(20): # 320 rows >> cap + model.predict(_make_batch(B, input_dim)) + self.assertEqual(model._n_seen, 20 * B) + self.assertEqual(model._n_filled, 10) + self.assertEqual(model._reservoir.shape, (10, input_dim)) + def test_on_train_end_runs_faiss(self) -> None: """on_train_end triggers FAISS fit and clears buffer.""" try: @@ -103,13 +117,14 @@ def test_on_train_end_runs_faiss(self) -> None: # Accumulate enough samples (FAISS K-Means needs at least K points) for _ in range(8): model.predict(_make_batch(B, input_dim)) - self.assertGreater(len(model._offline_buffer), 0) + self.assertGreater(model._n_seen, 0) # Trigger one-shot FAISS fit model.on_train_end() - # Buffer should be cleared - self.assertEqual(model._offline_buffer, []) + # Reservoir should be released after the fit + self.assertEqual(model._n_seen, 0) + self.assertIsNone(model._reservoir) # All layers should be initialized + centroids non-zero for layer in model._quantizer.layers: self.assertTrue(bool(layer._is_initialized.item())) diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 0385d5728..9e1f57814 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -89,6 +89,12 @@ message SidRqkmeans { // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, // spherical: false, seed: 1234}. optional google.protobuf.Struct faiss_kmeans_kwargs = 5; + // Target number of embeddings to reservoir-sample for the FAISS fit + // (global, across all ranks). Bounds host memory regardless of corpus + // size. 0 (the default) auto-derives it as K * max_points_per_centroid + // — exactly what FAISS subsamples to internally (default 256), so no + // training points are wasted. + optional uint32 train_sample_size = 6 [default = 0]; // Name of the item embedding feature inside the input Batch. optional string embedding_feature_name = 40 [default = "item_emb"]; From 1ac30d38a4b083f2ba8d7e5c76eda5e68359ded5 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 06:09:18 +0000 Subject: [PATCH 017/129] [refactor] RQ-VAE kmeans_init: use FAISS K-Means + broadcast (drop torch Lloyd) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (kmeans.py:134 "Why not use Faiss KMeans instead?" and residual_quantized.py:188 "averaging KMeans centroids across ranks is meaningless"). - New kmeans.faiss_residual_kmeans(): FAISS residual K-Means warm-start, the same backend the offline RQ-KMeans fit uses. Replaces the torch-native _kmeans / _kmeans_plus_plus / _residual_kmeans (deleted — they were the O(K^2 N), non-deterministic, single-batch Lloyd path only init used). - ResidualVectorQuantizer.init_embed_ now fits on rank 0 only and dist.broadcast(src=0)s the codebook, so every rank starts identical — instead of all_reduce-averaging permutation-misaligned per-rank centroids. Tests: swapped the torch-kmeans unit tests for faiss_residual_kmeans + a kmeans_init=True seeding test; added a 2-rank NCCL test asserting the broadcast yields bit-identical codebooks across ranks (validated on 2x GPU). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid_generation/kmeans.py | 136 +++++------------- tzrec/modules/sid_generation/kmeans_test.py | 27 ++-- .../sid_generation/residual_quantizer_test.py | 18 +++ .../residual_vector_quantizer.py | 32 +++-- .../residual_vector_quantizer_dist_test.py | 91 ++++++++++++ 5 files changed, 183 insertions(+), 121 deletions(-) create mode 100644 tzrec/modules/sid_generation/residual_vector_quantizer_dist_test.py diff --git a/tzrec/modules/sid_generation/kmeans.py b/tzrec/modules/sid_generation/kmeans.py index f089e751e..e32c3cf87 100644 --- a/tzrec/modules/sid_generation/kmeans.py +++ b/tzrec/modules/sid_generation/kmeans.py @@ -18,14 +18,12 @@ :class:`ResidualKMeansQuantizer`. Centroids are injected by the FAISS backend via ``load_centroids_``; the only forward path is ``predict``. -* :func:`_kmeans` / :func:`_residual_kmeans` — pure-torch Lloyd's - K-Means + residual variant, used by :class:`ResidualVectorQuantizer` to - warm-start the RQ-VAE codebook on the first training batch. They run - once on a single batch of encoder outputs (typically ~2k × 64), so - pulling in FAISS here would be all overhead and no benefit. +* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by + :class:`ResidualVectorQuantizer` to warm-start the RQ-VAE codebook on the + first training batch (same FAISS backend as the offline RQ-KMeans fit). """ -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch import nn @@ -92,109 +90,53 @@ def _squared_euclidean_distance( @torch.no_grad() -def _kmeans_plus_plus( - data: torch.Tensor, - n_clusters: int, -) -> torch.Tensor: - """KMeans++ initialization (Arthur & Vassilvitskii 2007). - - Selects initial centroids via distance-weighted probability sampling - to ensure well-spread starting points. Used by the RQ-VAE codebook - init path (``ResidualVectorQuantizer.kmeans_init``); the K-Means backend itself no - longer needs it. - - Args: - data (Tensor): data points, shape (N, D). - n_clusters (int): number of clusters K. - - Returns: - Tensor: initial centroids, shape (K, D). - """ - N, D = data.shape - centroids = torch.zeros(n_clusters, D, device=data.device, dtype=data.dtype) - - idx = torch.randint(0, N, (1,), device=data.device) - centroids[0] = data[idx] - - for i in range(1, n_clusters): - dists = _squared_euclidean_distance(data, centroids[:i]) # (N, i) - min_dists = dists.min(dim=1)[0] # (N,) - if min_dists.sum() == 0: - centroids[i:] = data[ - torch.randint(0, N, (n_clusters - i,), device=data.device) - ] - break - next_idx = torch.multinomial(min_dists, num_samples=1) - centroids[i] = data[next_idx] - - return centroids - - -@torch.no_grad() -def _kmeans( - samples: torch.Tensor, - n_clusters: int, - n_iters: int = 100, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Lloyd's K-Means with KMeans++ initialization. - - Used by :class:`ResidualVectorQuantizer.init_embed_` to warm-start the - RQ-VAE codebook on the first training batch. - - Args: - samples (Tensor): data points, shape (N, D). - n_clusters (int): number of clusters K. - n_iters (int): number of Lloyd iterations. Default: 100. - - Returns: - centroids (Tensor): cluster centers, shape (K, D). - assignments (Tensor): cluster indices, shape (N,). - """ - N, D = samples.shape - centroids = _kmeans_plus_plus(samples, n_clusters) - - for _ in range(n_iters): - dists = _squared_euclidean_distance(samples, centroids) # (N, K) - assignments = dists.argmin(dim=-1) # (N,) - - bins = torch.bincount(assignments, minlength=n_clusters) - zero_mask = bins == 0 - bins_clamped = bins.masked_fill(zero_mask, 1) - - new_centroids = torch.zeros_like(centroids) - new_centroids.scatter_add_(0, assignments.unsqueeze(1).expand(-1, D), samples) - new_centroids = new_centroids / bins_clamped.unsqueeze(1) - - # Keep old centroids for empty clusters - centroids = torch.where(zero_mask.unsqueeze(1), centroids, new_centroids) - - return centroids, assignments - - -@torch.no_grad() -def _residual_kmeans( +def faiss_residual_kmeans( samples: torch.Tensor, n_clusters_list: List[int], - n_iters: int = 100, + faiss_kmeans_kwargs: Optional[Dict] = None, ) -> List[torch.Tensor]: - """Residual K-Means: per-layer cluster then subtract centroids. + """Residual K-Means warm-start via FAISS, one pass per layer. - Used by :class:`ResidualVectorQuantizer.init_embed_` to seed every RQ - codebook layer in one pass over the first training batch. + Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned + centroid, and repeats on the residual for every layer. Used by + :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook + from the first training batch — the same FAISS backend the offline + RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. Args: samples (Tensor): data points, shape (N, D). n_clusters_list (List[int]): per-layer cluster counts. - n_iters (int): K-Means iterations per layer. Default: 100. + faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` + (e.g. ``{'niter': 10, 'seed': 123}``). Returns: - List[Tensor]: per-layer centroids ``[(K0, D), (K1, D), ...]``. + List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. + + Raises: + ImportError: if ``faiss`` is not installed. """ - res_centers = [] + try: + import faiss + except ImportError as e: + raise ImportError( + "faiss is required for RQ-VAE kmeans_init. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + kwargs = dict(faiss_kmeans_kwargs or {}) + device = samples.device + _, D = samples.shape + # Own a contiguous fp32 numpy copy we mutate in place to form residuals. + x = samples.detach().cpu().float().numpy().copy() + + res_centers: List[torch.Tensor] = [] for n_clusters in n_clusters_list: - centroids, assignments = _kmeans(samples, n_clusters, n_iters) - res_centers.append(centroids) - samples = samples - centroids[assignments] + kmeans = faiss.Kmeans(D, n_clusters, **kwargs) + kmeans.train(x) + centroids = kmeans.centroids.copy() # (K, D) + res_centers.append(torch.from_numpy(centroids).to(device)) + _, idx = kmeans.index.search(x, 1) + x -= centroids[idx.ravel()] # residual, in place return res_centers diff --git a/tzrec/modules/sid_generation/kmeans_test.py b/tzrec/modules/sid_generation/kmeans_test.py index 531b33126..f99c70d70 100644 --- a/tzrec/modules/sid_generation/kmeans_test.py +++ b/tzrec/modules/sid_generation/kmeans_test.py @@ -15,15 +15,14 @@ from tzrec.modules.sid_generation.kmeans import ( KMeansLayer, - _kmeans, - _residual_kmeans, _squared_euclidean_distance, + faiss_residual_kmeans, recon_diagnostics, ) class KmeansHelpersTest(unittest.TestCase): - """Tests for the pure-torch K-Means helpers.""" + """Tests for the K-Means helper functions.""" def test_recon_diagnostics_zero_on_identity(self) -> None: x = torch.randn(8, 4) @@ -46,21 +45,21 @@ def test_squared_euclidean_distance_chunked_matches(self) -> None: chunked = _squared_euclidean_distance(x, y, chunk_size=16) torch.testing.assert_close(full, chunked) - def test_kmeans_shapes_and_assignment_range(self) -> None: + def test_faiss_residual_kmeans_per_layer_centers(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") torch.manual_seed(0) - samples = torch.randn(200, 6) - centroids, assignments = _kmeans(samples, n_clusters=8, n_iters=5) - self.assertEqual(centroids.shape, (8, 6)) - self.assertEqual(assignments.shape, (200,)) - self.assertTrue((assignments >= 0).all() and (assignments < 8).all()) - - def test_residual_kmeans_per_layer_centers(self) -> None: - torch.manual_seed(0) - samples = torch.randn(200, 6) - centers = _residual_kmeans(samples, [8, 4], n_iters=5) + samples = torch.randn(512, 6) + centers = faiss_residual_kmeans( + samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} + ) self.assertEqual(len(centers), 2) self.assertEqual(centers[0].shape, (8, 6)) self.assertEqual(centers[1].shape, (4, 6)) + self.assertTrue(torch.isfinite(centers[0]).all()) + self.assertEqual(centers[0].device, samples.device) class KMeansLayerTest(unittest.TestCase): diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py index f5893cf4c..655a2f306 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -88,6 +88,24 @@ def test_get_codes_no_grad(self) -> None: codes = self.rvq.get_codes(torch.randn(4, 8)) self.assertEqual(codes.shape, (4, 3)) + def test_faiss_kmeans_init_seeds_codebook(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True + ) + self.assertFalse(bool(rvq.initted.item())) + rvq.train() + # First training forward triggers the FAISS warm-start. + rvq(torch.randn(512, 8)) + self.assertTrue(bool(rvq.initted.item())) + for layer in rvq.layers: + self.assertTrue(torch.isfinite(layer.embedding.weight).all()) + self.assertGreater(layer.embedding.weight.abs().sum().item(), 0.0) + class ResidualKMeansQuantizerTest(unittest.TestCase): def test_is_subclass(self) -> None: diff --git a/tzrec/modules/sid_generation/residual_vector_quantizer.py b/tzrec/modules/sid_generation/residual_vector_quantizer.py index 18ed0a480..a534550e2 100644 --- a/tzrec/modules/sid_generation/residual_vector_quantizer.py +++ b/tzrec/modules/sid_generation/residual_vector_quantizer.py @@ -18,7 +18,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid_generation.kmeans import _residual_kmeans +from tzrec.modules.sid_generation.kmeans import faiss_residual_kmeans from tzrec.modules.sid_generation.residual_quantizer import ResidualQuantizer from tzrec.modules.sid_generation.types import ( QuantizeForwardMode, @@ -155,10 +155,15 @@ def __init__( @torch.jit.ignore @torch.no_grad() def init_embed_(self, data: torch.Tensor) -> None: - """Initialize codebook weights via residual K-Means. + """Initialize codebook weights via FAISS residual K-Means. Only executed once when kmeans_init=True and not yet initialized. - Uses the first batch of training data as initialization pool. + Uses the first batch of training data as the initialization pool. + + Under DDP the codebook is fit on rank 0 only and broadcast, so every + rank starts from the SAME codebook. (Averaging per-rank centroids — + the previous behavior — mixes permutation-misaligned clusters across + ranks and yields a near-random warm start.) Args: data (Tensor): input data, shape (B, D). @@ -166,14 +171,21 @@ def init_embed_(self, data: torch.Tensor) -> None: if self.initted: return - centers = _residual_kmeans(data, self.n_embed_list) - - # Average per-layer centroids across DDP ranks so every rank - # starts from the same codebook. - if dist.is_initialized() and dist.get_world_size() > 1: + is_ddp = dist.is_initialized() and dist.get_world_size() > 1 + if (not is_ddp) or dist.get_rank() == 0: + centers = faiss_residual_kmeans( + data, + self.n_embed_list, + {"niter": 10, "seed": 123, "verbose": False}, + ) + else: + centers = [ + torch.empty(k, self.embed_dim, dtype=torch.float32, device=data.device) + for k in self.n_embed_list + ] + if is_ddp: for c in centers: - dist.all_reduce(c, op=dist.ReduceOp.SUM) - c /= dist.get_world_size() + dist.broadcast(c, src=0) for i, layer in enumerate(self.layers): layer.embedding.weight.data.copy_(centers[i]) diff --git a/tzrec/modules/sid_generation/residual_vector_quantizer_dist_test.py b/tzrec/modules/sid_generation/residual_vector_quantizer_dist_test.py new file mode 100644 index 000000000..b36e182a4 --- /dev/null +++ b/tzrec/modules/sid_generation/residual_vector_quantizer_dist_test.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-process test for ResidualVectorQuantizer FAISS kmeans-init. + +Validates the DDP path of ``init_embed_``: the codebook is fit on rank 0 +only and broadcast, so every rank ends with a bit-identical warm start. +(The previous behavior averaged permutation-misaligned per-rank centroids, +which the review flagged as a near-random init.) Uses NCCL on GPU when +>=2 devices are available, else gloo/CPU. +""" + +import os +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from tzrec.modules.sid_generation.residual_vector_quantizer import ( + ResidualVectorQuantizer, +) +from tzrec.utils import misc_util + +WORLD_SIZE = 2 + + +def _init(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _init_embed_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + # Rank-distinct data: a per-rank average/init would diverge; only a + # broadcast-from-rank0 init yields identical codebooks. + torch.manual_seed(rank) + rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True + ).to(device) + rvq.train() + rvq(torch.randn(512, 8, device=device)) # first forward triggers init_embed_ + assert bool(rvq.initted.item()), f"rank{rank}: not initialized" + + for layer in rvq.layers: + w = layer.embedding.weight.detach().clone() + wmin, wmax = w.clone(), w.clone() + dist.all_reduce(wmin, op=dist.ReduceOp.MIN) + dist.all_reduce(wmax, op=dist.ReduceOp.MAX) + assert torch.allclose(wmin, wmax), ( + f"rank{rank}: codebook differs across ranks (init not broadcast)" + ) + dist.destroy_process_group() + + +class ResidualVectorQuantizerDistTest(unittest.TestCase): + """2-rank test for the FAISS kmeans-init broadcast.""" + + def test_init_embed_broadcast(self) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=_init_embed_worker, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + +if __name__ == "__main__": + unittest.main() From f4fe6d50b58be924ad72d314ae3bedb278ea7318 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 06:26:01 +0000 Subject: [PATCH 018/129] [feat] RQ-KMeans: support non-uniform codebooks (e.g. [256, 512, 1024]) Per request. FAISS itself supports any K per instance; only our offline wrapper was restricted because train_offline reused a single faiss.Kmeans across layers (hence the uniformity assert). Now: - ResidualKMeansQuantizer.train_offline builds a fresh faiss.Kmeans per layer with that layer's K (index construction is a cheap O(K*D) alloc next to train(), so effectively free); uniformity assert removed. - SidRqkmeans reservoir cap now derives from max(n_embed_list) so the largest layer is fed K*max_points_per_centroid points (non-uniform would otherwise under-sample the big layer). - proto + docstrings updated; RQ-KMeans now matches RQ-VAE, which already supported per-layer K. Tests: swapped the "rejected" assert for non-uniform support + a non-uniform train_offline fit, and an end-to-end SidRqkmeans [8,4,16] test asserting the cap uses max(K) and per-layer codes stay in range. Module tests pass locally; model paths validated on the remote (protobuf 4.x) below. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 4 ++- tzrec/models/sid_rqkmeans_test.py | 35 +++++++++++++++++-- .../residual_kmeans_quantizer.py | 27 ++++++-------- .../sid_generation/residual_quantizer_test.py | 27 ++++++++++++-- tzrec/protos/models/sid_model.proto | 8 ++--- 5 files changed, 75 insertions(+), 26 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index dd67c0bbf..899e96613 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -100,7 +100,9 @@ def __init__( # buffering the full corpus is wasted memory. We reservoir-sample to # that target instead, split across ranks so the gathered set on # rank0 is ~train_sample_size and FAISS does no further subsampling. - k = self._n_embed_list[0] + # Use the LARGEST per-layer K so non-uniform codebooks (e.g. + # [256, 512, 1024]) still feed their biggest layer enough points. + k = max(self._n_embed_list) max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) global_target = ( cfg.train_sample_size if cfg.train_sample_size > 0 else k * max_ppc diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 03cd2c920..d7e347053 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -36,11 +36,11 @@ def _make_batch(batch_size: int, input_dim: int) -> Batch: class SidRqkmeansOfflineTest(unittest.TestCase): """Tests for SidRqkmeans (FAISS-only).""" - def _create_model(self, input_dim=32, n_layers=2, niter=5): + def _create_model(self, input_dim=32, n_layers=2, niter=5, codebook=None): """Create a SidRqkmeans configured for offline FAISS fit.""" from google.protobuf.struct_pb2 import Struct - n_embed_list = [16] * n_layers + n_embed_list = codebook if codebook is not None else [16] * n_layers faiss_kwargs = Struct() faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) @@ -137,6 +137,37 @@ def test_on_train_end_runs_faiss(self) -> None: self.assertEqual(codes.shape, (B, 2)) self.assertTrue((codes >= 0).all() and (codes < 16).all()) + def test_non_uniform_codebook_end_to_end(self) -> None: + """Non-uniform codebook [8, 4, 16]: fit then emit per-layer codes.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + codebook = [8, 4, 16] + model = self._create_model(input_dim=input_dim, codebook=codebook) + # Reservoir cap derives from the LARGEST K (16), not the first (8). + self.assertEqual( + model._sample_cap, + 16 * int(model._faiss_kwargs.get("max_points_per_centroid", 256)), + ) + + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + model.on_train_end() + + for k, layer in zip(codebook, model._quantizer.layers): + self.assertTrue(bool(layer._is_initialized.item())) + self.assertEqual(layer.centroids.shape[0], k) + + model.eval() + codes = model.predict(_make_batch(B, input_dim))["codes"] + self.assertEqual(codes.shape, (B, 3)) + for i, k in enumerate(codebook): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() diff --git a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py index 2596993b4..92436160b 100644 --- a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py @@ -45,9 +45,9 @@ class ResidualKMeansQuantizer(ResidualQuantizer): embed_dim (int): feature dimension. n_layers (int): number of residual quantization layers. n_embed (int|List[int]): number of clusters per layer. Default: 256. - All layers must share the same ``K`` — a single FAISS ``Kmeans`` - object is reused across layers (matches the OneRec reference). - Non-uniform codebooks are not supported. + May differ per layer (non-uniform codebooks such as + ``[256, 512, 1024]`` are supported) — ``train_offline`` builds a + separate ``faiss.Kmeans`` per layer. normalize_residuals (bool): whether to L2-normalize residuals before each layer. Default: False. faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to @@ -66,14 +66,6 @@ def __init__( super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) self.faiss_kmeans_kwargs = dict(faiss_kmeans_kwargs or {}) - # ``train_offline`` reuses a single ``faiss.Kmeans`` instance across - # layers, so non-uniform codebooks would silently train layers 1+ - # with ``K=n_embed_list[0]``. Fail fast instead. - assert len(set(self.n_embed_list)) == 1, ( - "ResidualKMeansQuantizer requires a uniform codebook size " - f"across layers; got {self.n_embed_list}." - ) - self.layers = nn.ModuleList( [ KMeansLayer( @@ -206,11 +198,6 @@ def train_offline( N, D = x.shape out = np.zeros((N, D), dtype=np.float32) - # Reuse one Kmeans instance across all layers (matches OneRec impl): - # rebuilding the FAISS object per layer doubles index-init cost. - n_embed = self.n_embed_list[0] - kmeans = faiss.Kmeans(self.embed_dim, n_embed, **self.faiss_kmeans_kwargs) - # Chunk size for index.search to limit peak memory. # 500K × 512 × 4B ≈ 1 GB per chunk. SEARCH_CHUNK = 500_000 @@ -221,6 +208,14 @@ def train_offline( np.maximum(norms, 1e-8, out=norms) x /= norms # in-place + # Fresh Kmeans per layer so each layer can use its own K + # (non-uniform codebooks supported). Index construction is a cheap + # O(K*D) allocation next to train(), so this is effectively free. + kmeans = faiss.Kmeans( + self.embed_dim, + self.n_embed_list[layer_idx], + **self.faiss_kmeans_kwargs, + ) kmeans.train(x) for start in range(0, N, SEARCH_CHUNK): diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py index 655a2f306..6aa8358e5 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -112,9 +112,12 @@ def test_is_subclass(self) -> None: rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) self.assertIsInstance(rkq, ResidualQuantizer) - def test_non_uniform_codebook_rejected(self) -> None: - with self.assertRaises(AssertionError): - ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=[8, 4]) + def test_non_uniform_codebook_supported(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) + self.assertEqual(rkq.n_embed_list, [8, 4, 16]) + self.assertEqual( + [layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16] + ) def test_forward_returns_zeros_before_fit(self) -> None: rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) @@ -123,6 +126,24 @@ def test_forward_returns_zeros_before_fit(self) -> None: self.assertEqual(codes.shape, (5, 2)) self.assertEqual(quantized.shape, (5, 4)) + def test_train_offline_non_uniform(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + n_embed = [8, 4, 16] + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(512, 4), verbose=False) + self.assertTrue(rkq.all_initialized) + # Each layer fit its own K centroids; codes stay in per-layer range. + codes, _ = rkq(torch.randn(7, 4)) + self.assertEqual(codes.shape, (7, 3)) + for i, k in enumerate(n_embed): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + def test_train_offline_then_decode(self) -> None: try: import faiss # noqa: F401 diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 9e1f57814..8bb00d8fa 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -78,10 +78,10 @@ message SidRqkmeans { // no encoder). optional uint32 input_dim = 1 [default = 512]; // Per-layer cluster counts, e.g. [256, 256, 256]. - // List length is the number of residual quantization layers. All - // entries must be equal — the FAISS backend reuses a single - // ``faiss.Kmeans`` object across layers, so non-uniform codebooks - // are not supported (a uniformity assert fires at construction). + // List length is the number of residual quantization layers. Entries + // may differ per layer (non-uniform codebooks such as [256, 512, 1024] + // are supported — the FAISS backend fits a separate ``faiss.Kmeans`` + // per layer). repeated uint32 codebook = 3; // L2-normalize residuals before each layer. optional bool normalize_residuals = 4 [default = true]; From f9fdea1b568d71797db529eb3f4f94117c93b259 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 07:52:52 +0000 Subject: [PATCH 019/129] [perf] RQ-KMeans train_offline: feed FAISS torch tensors, GPU-accelerate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review (residual_kmeans.py:229 "faiss can directly accept a torch tensor, do not need to convert numpy"). - import faiss.contrib.torch_utils and pass torch tensors to Kmeans.train / index.search directly — no numpy round-trips; centroids/codes flow as torch tensors. - Auto-select FAISS GPU compute when a faiss-gpu build is present (gpu=current_device, overridable via faiss_kmeans_kwargs['gpu']); falls back to CPU on faiss-cpu builds. The residual matrix stays a host tensor — FAISS streams only its subsampled (k*max_ppc) working set to the GPU, so we never hold (N,D) in VRAM (no A10 OOM risk). Same code path both ways. Measured ~80x faster training on GPU (0.2s vs 16s, N=262k/k=1024/niter=10). CPU path validated locally; GPU path validated on the H20 (faiss-gpu) below. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../residual_kmeans_quantizer.py | 64 ++++++++++--------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py index 92436160b..50789329f 100644 --- a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py @@ -158,12 +158,17 @@ def train_offline( ) -> None: """Train the multi-layer codebook via offline FAISS K-Means. + FAISS consumes torch tensors directly (via ``faiss.contrib. + torch_utils``) — no numpy round-trips. The residual matrix stays a + host (CPU) tensor; when a faiss-gpu build is present, ``gpu=`` + moves only FAISS's internal, subsampled working set to the GPU, so we + never hold (N, D) in VRAM. On a faiss-cpu build it runs on CPU + unchanged. Either way the code path is identical. + Args: - inputs: full embedding matrix, shape (N, D). Either a - ``torch.Tensor`` (will be copied to numpy) or a - ``np.ndarray`` (ownership transferred; caller MUST - release any outside reference — the array is mutated - in-place to compute residuals layer by layer). + inputs: full embedding matrix, shape (N, D), ``torch.Tensor`` or + ``np.ndarray``. Copied once to an owned CPU float32 tensor; + the caller's input is not mutated. verbose (bool): whether to print per-layer reconstruction loss. Default: True. @@ -172,31 +177,38 @@ def train_offline( """ try: import faiss + import faiss.contrib.torch_utils # noqa: F401 (torch tensor I/O) except ImportError as e: raise ImportError( "faiss is required for ResidualKMeansQuantizer training. Install via " "`pip install faiss-cpu` or `pip install faiss-gpu`." ) from e - # Materialise to a float32 contiguous numpy array that we own - # (so in-place residual updates are safe). + # Own a contiguous CPU float32 tensor we can update in place for + # residuals, without mutating the caller's input. if isinstance(inputs, torch.Tensor): assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # Tensor path still requires a copy; caller will hold a - # reference until we return, so we must not alias it. - x = inputs.detach().cpu().float().numpy().copy() + x = inputs.detach().to("cpu", torch.float32).contiguous().clone() else: assert inputs.ndim == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # Numpy path: take ownership — no extra copy. Caller promises - # the array is no longer used outside. Only ensure dtype - # + contiguity (zero-copy when already satisfied). - x = np.ascontiguousarray(inputs, dtype=np.float32) - N, D = x.shape - out = np.zeros((N, D), dtype=np.float32) + x = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float32)).clone() + N = x.shape[0] + out = torch.zeros_like(x) + + # Use FAISS GPU compute when a GPU build is available (data stays on + # host; FAISS streams only its subsampled training set to the device). + # An explicit ``gpu`` in faiss_kmeans_kwargs always wins. + kwargs = dict(self.faiss_kmeans_kwargs) + if "gpu" not in kwargs: + kwargs["gpu"] = ( + torch.cuda.current_device() + if faiss.get_num_gpus() > 0 and torch.cuda.is_available() + else False + ) # Chunk size for index.search to limit peak memory. # 500K × 512 × 4B ≈ 1 GB per chunk. @@ -204,40 +216,34 @@ def train_offline( for layer_idx in range(self.n_layers): if self.normalize_residuals: - norms = np.linalg.norm(x, axis=1, keepdims=True) - np.maximum(norms, 1e-8, out=norms) - x /= norms # in-place + x = F.normalize(x, dim=-1) # Fresh Kmeans per layer so each layer can use its own K # (non-uniform codebooks supported). Index construction is a cheap # O(K*D) allocation next to train(), so this is effectively free. kmeans = faiss.Kmeans( - self.embed_dim, - self.n_embed_list[layer_idx], - **self.faiss_kmeans_kwargs, + self.embed_dim, self.n_embed_list[layer_idx], **kwargs ) kmeans.train(x) + centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32).cpu() for start in range(0, N, SEARCH_CHUNK): end = min(start + SEARCH_CHUNK, N) _, idx = kmeans.index.search(x[start:end], 1) - q = kmeans.centroids[idx.ravel()] # (chunk, D) + idx = torch.as_tensor(idx, device="cpu").reshape(-1).long() + q = centroids[idx] # (chunk, D) out[start:end] += q x[start:end] -= q # residual del idx, q if verbose: - out_t = torch.from_numpy(out) - ref_t = torch.from_numpy(out + x) # x_in = out + residual logger.info( "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, - self._calc_loss(ref_t, out_t), + self._calc_loss(out + x, out), # x_in = out + residual ) - del out_t, ref_t - centroids_t = torch.from_numpy(kmeans.centroids.copy()) - self.layers[layer_idx].load_centroids_(centroids_t) + self.layers[layer_idx].load_centroids_(centroids) if verbose: logger.info( "[ResidualKMeansQuantizer][offline_faiss] layer %d finished", From b300adbe952bc2bc336631b540e6ae7db7c9ff55 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 12:25:42 +0000 Subject: [PATCH 020/129] [refactor] SID: hoist shared input_dim / normalize_residuals into BaseSidModel Both SID proto messages carry input_dim and normalize_residuals, and both models re-read them. Move the parsing into BaseSidModel.__init__ (alongside the already-shared embedding_feature_name and codebook), exposing self._input_dim / self._normalize_residuals. SidRqvae and SidRqkmeans now use the base attributes instead of re-reading cfg. No behavior change; sid_rqvae tests pass locally, full suite validated on the remote below. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 10 +++++++++- tzrec/models/sid_rqkmeans.py | 5 ++--- tzrec/models/sid_rqvae.py | 4 ++-- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index e48827a63..35ef27040 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -28,8 +28,11 @@ class BaseSidModel(BaseModel): Factors the structure common to :class:`SidRqvae` (RQ-VAE) and :class:`SidRqkmeans` (residual K-Means): + - the shared config fields every SID proto carries — + ``embedding_feature_name`` (``_embedding_feature_name``), ``input_dim`` + (``_input_dim``), ``normalize_residuals`` (``_normalize_residuals``), + and the per-layer ``codebook`` (``_n_embed_list`` / ``_n_layers``), - reading the item-embedding feature out of ``Batch.dense_features``, - - parsing the per-layer ``codebook`` into ``n_embed_list`` / ``n_layers``, - the eval metrics every SID model reports — reconstruction ``mse`` and ``unique_sid_ratio`` (codebook coverage). @@ -56,7 +59,12 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) cfg = self._model_config + # Config fields shared by every SID model (present on each SID proto + # message): the item-embedding feature, the input dimension, the + # residual-normalization toggle, and the per-layer codebook. self._embedding_feature_name = cfg.embedding_feature_name + self._input_dim = cfg.input_dim + self._normalize_residuals = cfg.normalize_residuals assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" self._n_embed_list = list(cfg.codebook) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 899e96613..b1a1b07f1 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -86,12 +86,11 @@ def __init__( else {} ) - self._input_dim = cfg.input_dim self._quantizer = ResidualKMeansQuantizer( - embed_dim=cfg.input_dim, + embed_dim=self._input_dim, n_layers=self._n_layers, n_embed=self._n_embed_list, - normalize_residuals=cfg.normalize_residuals, + normalize_residuals=self._normalize_residuals, faiss_kmeans_kwargs=self._faiss_kwargs, ) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 1478bd268..125ea0211 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -87,7 +87,7 @@ def __init__( cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None ) - input_dim = cfg.input_dim + input_dim = self._input_dim # shared field parsed by BaseSidModel embed_dim = cfg.embed_dim hidden_dims = list(cfg.hidden_dims) if cfg.hidden_dims else [input_dim // 2] # latent_weight defaults to (1.0, 0.5) when the user leaves the @@ -111,7 +111,7 @@ def __init__( n_layers=self._n_layers, n_embed=self._n_embed_list, forward_mode=cfg.forward_mode, - normalize_residuals=cfg.normalize_residuals, + normalize_residuals=self._normalize_residuals, distance_type=cfg.distance_type, commitment_loss=cfg.commitment_loss, latent_weight=latent_weight, From 345f1338f9bc962dc6fa7353d9a096baaa8eb1e9 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 12:28:39 +0000 Subject: [PATCH 021/129] [chore] SidRqvae: use self._input_dim directly, drop redundant local alias MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to hoisting input_dim into BaseSidModel — there's no need for a local `input_dim = self._input_dim` alias; reference the base attribute directly in the encoder/decoder dims and init log. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 125ea0211..aeacd205f 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -87,9 +87,10 @@ def __init__( cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None ) - input_dim = self._input_dim # shared field parsed by BaseSidModel embed_dim = cfg.embed_dim - hidden_dims = list(cfg.hidden_dims) if cfg.hidden_dims else [input_dim // 2] + hidden_dims = ( + list(cfg.hidden_dims) if cfg.hidden_dims else [self._input_dim // 2] + ) # latent_weight defaults to (1.0, 0.5) when the user leaves the # repeated field empty. latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) @@ -102,9 +103,11 @@ def __init__( sinkhorn_iters = cfg.sinkhorn_config.iters sinkhorn_epsilon = cfg.sinkhorn_config.epsilon - self._encoder = self._build_mlp([input_dim, *hidden_dims, embed_dim]) + self._encoder = self._build_mlp([self._input_dim, *hidden_dims, embed_dim]) # Decoder is the symmetric reverse of the encoder. - self._decoder = self._build_mlp([embed_dim, *reversed(hidden_dims), input_dim]) + self._decoder = self._build_mlp( + [embed_dim, *reversed(hidden_dims), self._input_dim] + ) self._quantizer = ResidualVectorQuantizer( embed_dim=embed_dim, @@ -132,7 +135,7 @@ def __init__( logger.info( "SidRqvae init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " "n_layers=%d, n_embed=%s, loss_type=%s, use_clip=%s", - input_dim, + self._input_dim, embed_dim, hidden_dims, self._n_layers, From 0fbff37d13529836a16ce96346bf3a10af1693f0 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 1 Jun 2026 12:49:50 +0000 Subject: [PATCH 022/129] [fix] SID: wire on_train_end lifecycle hook into BaseModel + main.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The SID port brought the models/modules/protos but not the lifecycle wiring, so SidRqkmeans.on_train_end (which runs the FAISS fit) was never invoked by a real train_eval run — only the unit tests called it directly. A real run would finish with an unfit (zero) codebook and predict would emit all-zero codes. - BaseModel.on_train_end: add the no-op base hook (so every model has it). - main.py: call _model.on_train_end() after the train loop, and force the tail-save to fire afterwards (last_ckpt_step guard) so the post-hook state — e.g. the freshly fit FAISS codebook — is always persisted, even when the last in-loop checkpoint coincided with the final step. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 9 +++++++++ tzrec/models/model.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/tzrec/main.py b/tzrec/main.py index 87f2984fb..8824e8373 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -500,6 +500,15 @@ def _train_and_evaluate( if lr.by_epoch: lr.step() + # One-shot end-of-loop hook (default no-op). Some models do real work + # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings + # collected during training. Since that mutates model state, force the + # tail-save below to fire so the post-hook state is persisted even when + # the last in-loop checkpoint coincided with the final step. + _model.on_train_end() + if last_ckpt_step == i_step: + last_ckpt_step = -1 + _log_train( i_step, losses, diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 40da5335a..10fa8aae5 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -150,6 +150,15 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: metric_results[metric_name] = metric.compute() return metric_results + def on_train_end(self) -> None: + """Hook fired once after the train_eval loop exits. + + Default: no-op. Override in models that need one-shot end-of-loop + work — e.g. :class:`SidRqkmeans` uses this hook to fit the FAISS + codebook from the embedding sample it collected during training. + """ + pass + def sparse_parameters( self, ) -> Tuple[Iterable[nn.Parameter], Iterable[nn.Parameter]]: From 0441ff23c033c90a846c01d73191e7556dce0675 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 2 Jun 2026 03:33:22 +0000 Subject: [PATCH 023/129] [refactor] SidRqvae: merge _recon_loss and _masked_recon_loss Both computed the same per-loss_type reconstruction loss (mse/l1/cosine); the only difference was the reduction. Fold into a single _recon_loss with an optional per-row mask: no mask -> mean over all rows (== the old reduction="mean"); mask -> mean over the masked-in rows (the mixed recon+CLIP path). _forward_mixed now calls _recon_loss(..., recon_mask). Behavior unchanged; 12/12 sid_rqvae tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 52 +++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index aeacd205f..c17821eaa 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -144,8 +144,6 @@ def __init__( self._use_clip, ) - # ----- encode / decode / loss helpers (formerly RQVAE) ----- - def _encode(self, x: torch.Tensor) -> torch.Tensor: """Encode. (B, input_dim) -> (B, embed_dim).""" return self._encoder(x) @@ -154,14 +152,34 @@ def _decode(self, z_q: torch.Tensor) -> torch.Tensor: """Decode. (B, embed_dim) -> (B, input_dim).""" return self._decoder(z_q) - def _recon_loss(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """Mean reconstruction loss for the configured ``loss_type``.""" + def _recon_loss( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Reconstruction loss for the configured ``loss_type``. + + Returns the mean over all rows, or — when ``mask`` (a per-row bool) + is given — the mean over only the masked-in rows (the mixed + recon+CLIP path applies recon loss to recon rows only). No + data-dependent branching, so it stays ``torch.compile``-friendly. + + Args: + x_hat (Tensor): reconstructed output, shape (B, D). + x (Tensor): original input, shape (B, D). + mask (Tensor, optional): per-row bool; rows to include. + """ if self._loss_type == "mse": - return F.mse_loss(x_hat, x, reduction="mean") + per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) elif self._loss_type == "l1": - return F.l1_loss(x_hat, x, reduction="mean") + per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) else: # 'cosine' - return (1 - F.cosine_similarity(x_hat, x, dim=1)).mean() + per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) + if mask is None: + return per_sample.mean() + mask = mask.float() + return (per_sample * mask).sum() / mask.sum().clamp(min=1) def _forward_rqvae( self, x: torch.Tensor, temperature: float = 1.0 @@ -182,22 +200,6 @@ def _forward_rqvae( "loss": recon_loss + quant_loss, } - def _masked_recon_loss( - self, - x_hat: torch.Tensor, - x: torch.Tensor, - recon_mask: torch.Tensor, - ) -> torch.Tensor: - """Per-sample recon loss masked to recon rows (no data-dependent branch).""" - if self._loss_type == "mse": - per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) - elif self._loss_type == "l1": - per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - else: # 'cosine' - per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) - n_recon = recon_mask.float().sum().clamp(min=1) - return (per_sample * recon_mask.float()).sum() / n_recon - def _forward_mixed( self, fea1: torch.Tensor, @@ -215,7 +217,7 @@ def _forward_mixed( x_hat2 = self._decode(quant2.quantized_embeddings) recon_mask = ~clip_mask - recon_loss = self._masked_recon_loss(x_hat1, fea1, recon_mask) + recon_loss = self._recon_loss(x_hat1, fea1, recon_mask) features = { "image_embed": x_hat1, @@ -238,8 +240,6 @@ def _forward_mixed( "commitment_loss": commitment, } - # ----- BaseModel interface ----- - def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. From aa8110da7fcac22744c82b018f25ba6454ee73fe Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 2 Jun 2026 03:40:01 +0000 Subject: [PATCH 024/129] [refactor] SidRqvae: unify loss keys across recon + CLIP paths The standard and mixed-CLIP paths emitted the same two losses under different dict keys: reconstruction_loss/recon_loss (the reconstruction loss) and quantization_loss/commitment_loss (the RVQ commitment loss). That also made the logged metric names differ by mode. Standardize on reconstruction_loss + quantization_loss everywhere (matches the quantizer's ResidualQuantizerOutput.quantization_loss field). loss() now always emits reconstruction_loss + quantization_loss, plus clip_loss when use_clip. Left the commitment_loss= constructor arg (the loss-type knob) and the _recon_loss method name untouched. Tests updated; 12/12 pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 15 ++++++--------- tzrec/models/sid_rqvae_test.py | 22 +++++++++++----------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index c17821eaa..94531d31c 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -235,9 +235,9 @@ def _forward_mixed( "codes": quant1.cluster_ids, "quantized": quant1.quantized_embeddings, "x_hat": x_hat1, - "recon_loss": recon_loss, + "reconstruction_loss": recon_loss, "clip_loss": clip_result["clip_loss"], - "commitment_loss": commitment, + "quantization_loss": commitment, } def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: @@ -293,9 +293,9 @@ def _predict_mixed( "codes": result["codes"], "quantized": result["quantized"], "x_hat": result["x_hat"], - "recon_loss": result["recon_loss"], + "reconstruction_loss": result["reconstruction_loss"], "clip_loss": result["clip_loss"], - "commitment_loss": result["commitment_loss"], + "quantization_loss": result["quantization_loss"], } return predictions @@ -312,13 +312,10 @@ def loss( losses (dict): a dict of loss tensor. """ losses: Dict[str, torch.Tensor] = {} + losses["reconstruction_loss"] = predictions["reconstruction_loss"] + losses["quantization_loss"] = predictions["quantization_loss"] if self._use_clip: - losses["recon_loss"] = predictions["recon_loss"] losses["clip_loss"] = predictions["clip_loss"] - losses["commitment_loss"] = predictions["commitment_loss"] - else: - losses["reconstruction_loss"] = predictions["reconstruction_loss"] - losses["quantization_loss"] = predictions["quantization_loss"] return losses def init_metric(self) -> None: diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 686a9a720..56542dd33 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -168,19 +168,19 @@ def test_rqvae_clip_mode(self) -> None: predictions = model.predict(batch) - # Mixed mode should return recon_loss, clip_loss, commitment_loss + # Mixed mode returns reconstruction_loss, clip_loss, quantization_loss self.assertIn("codes", predictions) - self.assertIn("recon_loss", predictions) + self.assertIn("reconstruction_loss", predictions) self.assertIn("clip_loss", predictions) - self.assertIn("commitment_loss", predictions) + self.assertIn("quantization_loss", predictions) self.assertIn("x_hat", predictions) self.assertEqual(predictions["codes"].shape[0], B) # Loss should return all three losses = model.loss(predictions, batch) - self.assertIn("recon_loss", losses) + self.assertIn("reconstruction_loss", losses) self.assertIn("clip_loss", losses) - self.assertIn("commitment_loss", losses) + self.assertIn("quantization_loss", losses) total_loss = sum(losses.values()) self.assertTrue(total_loss.requires_grad) @@ -220,8 +220,8 @@ def test_rqvae_clip_all_recon(self) -> None: # clip_loss should be 0 (no clip rows) self.assertEqual(predictions["clip_loss"].item(), 0.0) - # recon_loss should be > 0 - self.assertGreater(predictions["recon_loss"].item(), 0.0) + # reconstruction_loss should be > 0 + self.assertGreater(predictions["reconstruction_loss"].item(), 0.0) def test_rqvae_clip_all_clip(self) -> None: """Test mixed mode with all-clip batch (edge case).""" @@ -249,8 +249,8 @@ def test_rqvae_clip_all_clip(self) -> None: predictions = model.predict(batch) model.loss(predictions, batch) - # recon_loss should be 0 (no recon rows) - self.assertEqual(predictions["recon_loss"].item(), 0.0) + # reconstruction_loss should be 0 (no recon rows) + self.assertEqual(predictions["reconstruction_loss"].item(), 0.0) # clip_loss should be > 0 self.assertGreater(predictions["clip_loss"].item(), 0.0) @@ -301,8 +301,8 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: ) predictions = model.predict(batch) - # All rows flagged as clip -> recon_loss should be 0, clip_loss > 0 - self.assertEqual(predictions["recon_loss"].item(), 0.0) + # All rows flagged as clip -> reconstruction_loss should be 0, clip_loss > 0 + self.assertEqual(predictions["reconstruction_loss"].item(), 0.0) self.assertGreater(predictions["clip_loss"].item(), 0.0) def test_commitment_loss_l1_branch(self) -> None: From 5aae8ff77531eff29f9059fc0eadc849ebb129a0 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 2 Jun 2026 06:20:08 +0000 Subject: [PATCH 025/129] [test] SID: merge sid_rqkmeans_dist_test into sid_rqkmeans_test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The two files tested the same model at different scales — single-process (proto parse, reservoir, FAISS fit, checkpoints, non-uniform) vs a 2-rank multi-process on_train_end DDP path. They duplicated _make_batch and the model-config builder. Merge into one file (matching the tzrec convention of co-locating dist tests, e.g. checkpoint_util_test): - Shared module-level helpers: _make_batch(..., device="cpu") and _build_model(...); the unit class's _create_model now wraps _build_model + init_parameters, and the spawned DDP worker reuses _build_model. - SidRqkmeansOfflineTest (single-process) + SidRqkmeansDistTest (2-rank, NCCL on GPU else gloo) now live together; deleted sid_rqkmeans_dist_test.py. Logic unchanged from the previously remote-validated tests (8 unit + 1 DDP). Structure verified locally (imports, both classes, picklable worker); full run needs a protobuf-4.x env (config_to_kwargs), as before. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans_dist_test.py | 139 ------------------------- tzrec/models/sid_rqkmeans_test.py | 131 ++++++++++++++++++----- 2 files changed, 106 insertions(+), 164 deletions(-) delete mode 100644 tzrec/models/sid_rqkmeans_dist_test.py diff --git a/tzrec/models/sid_rqkmeans_dist_test.py b/tzrec/models/sid_rqkmeans_dist_test.py deleted file mode 100644 index 82a96f2ac..000000000 --- a/tzrec/models/sid_rqkmeans_dist_test.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multi-process tests for SidRqkmeans.on_train_end's DDP code path. - -This exercises the collective sequence the single-process unit test -cannot reach: the cross-rank empty-buffer all_reduce, ``gather_object`` -of the per-rank embedding buffers to rank 0, the FAISS fit, and the -``broadcast`` of centroids + ``_is_initialized`` fill on every rank. - -Uses NCCL on GPU when >=2 devices are available (the production backend -the reviewer flagged for ``gather_object``), else gloo/CPU. -""" - -import os -import unittest - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from torchrec import KeyedTensor - -from tzrec.datasets.utils import BASE_DATA_GROUP, Batch -from tzrec.protos import model_pb2 -from tzrec.protos.models import sid_model_pb2 -from tzrec.utils import misc_util - -WORLD_SIZE = 2 - - -def _init(rank: int, world_size: int, port: int) -> torch.device: - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size - if use_cuda: - torch.cuda.set_device(rank) - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - return torch.device(f"cuda:{rank}") - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) - return torch.device("cpu") - - -def _make_batch(batch_size: int, input_dim: int, device: torch.device) -> Batch: - dense = KeyedTensor.from_tensor_list( - keys=["item_emb"], tensors=[torch.randn(batch_size, input_dim, device=device)] - ) - return Batch( - dense_features={BASE_DATA_GROUP: dense}, sparse_features={}, labels={} - ) - - -def _create_model(input_dim: int, n_layers: int, k: int): - from google.protobuf.struct_pb2 import Struct - - from tzrec.models.sid_rqkmeans import SidRqkmeans - - faiss_kwargs = Struct() - faiss_kwargs.update({"niter": 5, "verbose": False, "seed": 1234}) - cfg = sid_model_pb2.SidRqkmeans( - input_dim=input_dim, - codebook=[k] * n_layers, - normalize_residuals=False, - faiss_kmeans_kwargs=faiss_kwargs, - embedding_feature_name="item_emb", - ) - model_config = model_pb2.ModelConfig(sid_rqkmeans=cfg) - return SidRqkmeans(model_config=model_config, features=[], labels=[]) - - -def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: - device = _init(rank, world_size, port) - input_dim, n_layers, k = 16, 2, 16 - model = _create_model(input_dim, n_layers, k).to(device) - model.train() - - torch.manual_seed(100 + rank) - for _ in range(6): - model.predict(_make_batch(32, input_dim, device)) - assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" - - # The collective sequence under test: empty-flag all_reduce -> - # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. - model.on_train_end() - - # Every rank must end initialized with non-zero centroids. - for layer in model._quantizer.layers: - assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" - assert layer.centroids.abs().sum().item() > 0.0, f"rank{rank}: zero centroids" - - # Centroids were broadcast from rank0 -> must be bit-identical across - # ranks (min == max under all_reduce). - for layer in model._quantizer.layers: - cmin = layer.centroids.clone() - cmax = layer.centroids.clone() - dist.all_reduce(cmin, op=dist.ReduceOp.MIN) - dist.all_reduce(cmax, op=dist.ReduceOp.MAX) - assert torch.allclose(cmin, cmax), f"rank{rank}: centroids differ across ranks" - - # After the fit, eval predict emits valid codes. - model.eval() - codes = model.predict(_make_batch(8, input_dim, device))["codes"] - assert codes.shape == (8, n_layers), f"rank{rank}: bad codes shape {codes.shape}" - assert (codes >= 0).all() and (codes < k).all(), f"rank{rank}: codes out of range" - dist.destroy_process_group() - - -def _run(target) -> None: - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(WORLD_SIZE): - p = ctx.Process(target=target, args=(rank, WORLD_SIZE, port)) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join() - if p.exitcode != 0: - raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") - - -class SidRqkmeansDistTest(unittest.TestCase): - """2-rank test for SidRqkmeans.on_train_end.""" - - def test_on_train_end_ddp(self) -> None: - _run(_on_train_end_worker) - - -if __name__ == "__main__": - unittest.main() diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index d7e347053..8b224afac 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -9,22 +9,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import torch +import torch.distributed as dist +import torch.multiprocessing as mp from torchrec import KeyedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.models.sid_rqkmeans import SidRqkmeans from tzrec.protos import model_pb2 from tzrec.protos.models import sid_model_pb2 +from tzrec.utils import misc_util from tzrec.utils.state_dict_util import init_parameters +WORLD_SIZE = 2 -def _make_batch(batch_size: int, input_dim: int) -> Batch: + +def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: """Create a minimal Batch with dense embedding features.""" dense_feature = KeyedTensor.from_tensor_list( - keys=["item_emb"], tensors=[torch.randn(batch_size, input_dim)] + keys=["item_emb"], + tensors=[torch.randn(batch_size, input_dim, device=device)], ) return Batch( dense_features={BASE_DATA_GROUP: dense_feature}, @@ -33,32 +40,39 @@ def _make_batch(batch_size: int, input_dim: int) -> Batch: ) -class SidRqkmeansOfflineTest(unittest.TestCase): - """Tests for SidRqkmeans (FAISS-only).""" - - def _create_model(self, input_dim=32, n_layers=2, niter=5, codebook=None): - """Create a SidRqkmeans configured for offline FAISS fit.""" - from google.protobuf.struct_pb2 import Struct +def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmeans: + """Build a SidRqkmeans configured for offline FAISS fit. + + Module-level (not a method) so the spawned DDP workers below can build + the same model; callers move it to a device / init params as needed. + SID models read the item-embedding dense feature directly from the batch + and do not consume feature_groups, so none is set. + """ + from google.protobuf.struct_pb2 import Struct + + n_embed_list = codebook if codebook is not None else [16] * n_layers + faiss_kwargs = Struct() + faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + cfg = sid_model_pb2.SidRqkmeans( + input_dim=input_dim, + codebook=n_embed_list, + normalize_residuals=False, + faiss_kmeans_kwargs=faiss_kwargs, + embedding_feature_name="item_emb", + ) + return SidRqkmeans( + model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), + features=[], + labels=[], + ) - n_embed_list = codebook if codebook is not None else [16] * n_layers - faiss_kwargs = Struct() - faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) +class SidRqkmeansOfflineTest(unittest.TestCase): + """Single-process tests for SidRqkmeans (FAISS-only).""" - sid_rqkmeans_cfg = sid_model_pb2.SidRqkmeans( - input_dim=input_dim, - codebook=n_embed_list, - normalize_residuals=False, - faiss_kmeans_kwargs=faiss_kwargs, - embedding_feature_name="item_emb", - ) - # SID models read the item-embedding dense feature directly from the - # batch; they do not consume feature_groups, so none is set (which - # keeps the config consistent with the empty ``features`` list). - model_config = model_pb2.ModelConfig( - sid_rqkmeans=sid_rqkmeans_cfg, - ) - model = SidRqkmeans(model_config=model_config, features=[], labels=[]) + def _create_model(self, input_dim=32, n_layers=2, niter=5, codebook=None): + """Create a SidRqkmeans on CPU with params initialized.""" + model = _build_model(input_dim, n_layers, niter, codebook) init_parameters(model, device=torch.device("cpu")) return model @@ -221,5 +235,72 @@ def test_mid_fit_checkpoint_rejected_on_load(self) -> None: fresh.load_state_dict(sd) +# -------------------------------------------------------------------------- +# Distributed (multi-process) test for the DDP on_train_end path: the +# cross-rank gather_object -> FAISS fit -> broadcast sequence the in-process +# tests above cannot reach. NCCL on GPU when >=2 devices, else gloo/CPU. +# -------------------------------------------------------------------------- +def _init_dist(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: + device = _init_dist(rank, world_size, port) + input_dim, n_layers, k = 16, 2, 16 + model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) + model.train() + + torch.manual_seed(100 + rank) + for _ in range(6): + model.predict(_make_batch(32, input_dim, device)) + assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" + + # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. + model.on_train_end() + + for layer in model._quantizer.layers: + assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" + assert layer.centroids.abs().sum().item() > 0.0, f"rank{rank}: zero centroids" + # Centroids were broadcast from rank0 -> must be bit-identical across ranks. + for layer in model._quantizer.layers: + cmin, cmax = layer.centroids.clone(), layer.centroids.clone() + dist.all_reduce(cmin, op=dist.ReduceOp.MIN) + dist.all_reduce(cmax, op=dist.ReduceOp.MAX) + assert torch.allclose(cmin, cmax), f"rank{rank}: centroids differ across ranks" + + model.eval() + codes = model.predict(_make_batch(8, input_dim, device))["codes"] + assert codes.shape == (8, n_layers), f"rank{rank}: bad codes shape {codes.shape}" + assert (codes >= 0).all() and (codes < k).all(), f"rank{rank}: codes out of range" + dist.destroy_process_group() + + +class SidRqkmeansDistTest(unittest.TestCase): + """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" + + def test_on_train_end_ddp(self) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=_on_train_end_worker, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + if __name__ == "__main__": unittest.main() From d67f923994199d5fc77157046ef20c6d56db34b0 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 2 Jun 2026 06:24:17 +0000 Subject: [PATCH 026/129] [chore] RQ-KMeans: drop unused all_initialized property MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The all_initialized property had no production caller — forward() checks each layer's is_initialized individually, and SidRqkmeans never used it. It was referenced only by residual_quantizer_test. Removed it; the tests now check all(layer.is_initialized for layer in rkq.layers) inline. 15/15 residual_quantizer tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid_generation/residual_kmeans_quantizer.py | 5 ----- tzrec/modules/sid_generation/residual_quantizer_test.py | 6 +++--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py index 50789329f..0de529a43 100644 --- a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py @@ -76,11 +76,6 @@ def __init__( ] ) - @property - def all_initialized(self) -> bool: - """Whether all layers have been initialized via offline FAISS.""" - return all(layer.is_initialized for layer in self.layers) - def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Assign codes per layer and sum the centroids. diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py index 6aa8358e5..fc13a0c62 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -121,7 +121,7 @@ def test_non_uniform_codebook_supported(self) -> None: def test_forward_returns_zeros_before_fit(self) -> None: rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - self.assertFalse(rkq.all_initialized) + self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) codes, quantized = rkq(torch.randn(5, 4)) self.assertEqual(codes.shape, (5, 2)) self.assertEqual(quantized.shape, (5, 4)) @@ -137,7 +137,7 @@ def test_train_offline_non_uniform(self) -> None: embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} ) rkq.train_offline(torch.randn(512, 4), verbose=False) - self.assertTrue(rkq.all_initialized) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) # Each layer fit its own K centroids; codes stay in per-layer range. codes, _ = rkq(torch.randn(7, 4)) self.assertEqual(codes.shape, (7, 3)) @@ -154,7 +154,7 @@ def test_train_offline_then_decode(self) -> None: embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} ) rkq.train_offline(torch.randn(256, 4), verbose=False) - self.assertTrue(rkq.all_initialized) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) codes, _ = rkq(torch.randn(5, 4)) self.assertTrue((codes >= 0).all() and (codes < 8).all()) From 3856bbc50bb64fa918a10cab2077262b5124b085 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 2 Jun 2026 06:46:01 +0000 Subject: [PATCH 027/129] [bugfix] KMeans predict: drop data-dependent chunking so forward is FX-traceable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit torchrec's inference pipeline symbolically traces the model. The `if N <= chunk_size:` branch in `_squared_euclidean_distance` keyed off the traced batch dim, raising `torch.fx.proxy.TraceError` during predict export. The chunked path was only reachable per-batch from `KMeansLayer.predict` (small N — the offline fit uses FAISS, not this function), so the chunking was unnecessary as well as FX-breaking. Simplify to a branch-free (x_sq + y_sq - 2 x@y.T).clamp(min=0). Drop the now-dead chunk_size param and its test; add an FX symbolic-trace regression test on ResidualKMeansQuantizer.forward. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid_generation/kmeans.py | 30 +++++++------------ tzrec/modules/sid_generation/kmeans_test.py | 7 ----- .../sid_generation/residual_quantizer_test.py | 19 ++++++++++++ 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/tzrec/modules/sid_generation/kmeans.py b/tzrec/modules/sid_generation/kmeans.py index e32c3cf87..0b6fe4255 100644 --- a/tzrec/modules/sid_generation/kmeans.py +++ b/tzrec/modules/sid_generation/kmeans.py @@ -57,36 +57,26 @@ def recon_diagnostics( @torch.no_grad() -def _squared_euclidean_distance( - x: torch.Tensor, - y: torch.Tensor, - chunk_size: int = 50000, -) -> torch.Tensor: - """Squared L2 distance with chunked computation for memory efficiency. - - Chunks the rows of ``x`` so peak memory is bounded by - ``chunk_size * K * 4 bytes`` (fp32) regardless of ``N``. +def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Squared L2 distance between rows of ``x`` and ``y``. Args: x (Tensor): data points, shape (N, D). y (Tensor): centroids, shape (K, D). - chunk_size (int): max rows of x per chunk. Default: 50000. Returns: Tensor: squared distances, shape (N, K). + + Called per-batch from :meth:`KMeansLayer.predict`, so ``N`` is the batch + size and the full (N, K) product is small. Kept branch-free (no + data-dependent chunking on ``N``) so the predict forward stays + FX-traceable: torchrec's inference pipeline symbolically traces the + model, and a ``if N <= chunk_size`` on the traced batch dim raises a + ``torch.fx`` TraceError. """ x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) - N = x.shape[0] - if N <= chunk_size: - return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) - out = x.new_empty(N, y.shape[0]) - for start in range(0, N, chunk_size): - end = min(start + chunk_size, N) - out[start:end] = (x_sq[start:end] + y_sq - 2.0 * x[start:end] @ y.t()).clamp_( - min=0.0 - ) - return out + return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) @torch.no_grad() diff --git a/tzrec/modules/sid_generation/kmeans_test.py b/tzrec/modules/sid_generation/kmeans_test.py index f99c70d70..52d685b0e 100644 --- a/tzrec/modules/sid_generation/kmeans_test.py +++ b/tzrec/modules/sid_generation/kmeans_test.py @@ -38,13 +38,6 @@ def test_squared_euclidean_distance(self) -> None: # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) - def test_squared_euclidean_distance_chunked_matches(self) -> None: - x = torch.randn(120, 5) - y = torch.randn(7, 5) - full = _squared_euclidean_distance(x, y, chunk_size=1000) - chunked = _squared_euclidean_distance(x, y, chunk_size=16) - torch.testing.assert_close(full, chunked) - def test_faiss_residual_kmeans_per_layer_centers(self) -> None: try: import faiss # noqa: F401 diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py index fc13a0c62..e26a7033b 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -126,6 +126,25 @@ def test_forward_returns_zeros_before_fit(self) -> None: self.assertEqual(codes.shape, (5, 2)) self.assertEqual(quantized.shape, (5, 4)) + def test_forward_is_fx_traceable(self) -> None: + """Predict forward must FX-trace. + + torchrec's inference pipeline symbolically traces the model, so the + per-batch distance path must be free of data-dependent control flow. + """ + import torch.fx as fx + + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + for layer in rkq.layers: # populate centroids -> is_initialized=True + layer.load_centroids_(torch.randn(8, 4)) + traced = fx.symbolic_trace(rkq) + x = torch.randn(5, 4) + c_eager, q_eager = rkq(x) + c_traced, q_traced = traced(x) + torch.testing.assert_close(c_traced, c_eager) + torch.testing.assert_close(q_traced, q_eager) + def test_train_offline_non_uniform(self) -> None: try: import faiss # noqa: F401 From 39c88a4ed0a519f331593e21510008d8f32e120f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 2 Jun 2026 08:55:56 +0000 Subject: [PATCH 028/129] [guard] VectorQuantize: reject use_sinkhorn + GUMBEL_SOFTMAX combo Sinkhorn and Gumbel-Softmax pick the code by two different rules: with Sinkhorn on, `ids = Q.argmax` (balanced optimal-transport assignment), while the Gumbel branch builds `emb` from argmax(-distances + noise) (nearest code). The two indices generally diverge, so the saved semantic ID would not match the codebook vector actually reconstructed and trained. STE has no such issue since it looks up embedding(ids) directly. Add a constructor assert forbidding the inconsistent combo (STE+Sinkhorn or Gumbel-without-Sinkhorn remain valid), retarget the gumbel test param to use_sinkhorn=False, and add a test that the rejected combo raises. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid_generation/vector_quantize.py | 13 +++++++++++++ .../modules/sid_generation/vector_quantize_test.py | 13 ++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tzrec/modules/sid_generation/vector_quantize.py b/tzrec/modules/sid_generation/vector_quantize.py index bc429ca75..427cb50fc 100644 --- a/tzrec/modules/sid_generation/vector_quantize.py +++ b/tzrec/modules/sid_generation/vector_quantize.py @@ -142,6 +142,19 @@ def __init__( sinkhorn_epsilon: float = 10.0, ) -> None: super().__init__() + # Sinkhorn + Gumbel-Softmax pick the code by two different rules: + # `ids` come from the Sinkhorn balanced-assignment argmax, while the + # Gumbel branch builds `emb` from argmax(-distances + noise) (nearest + # code). The two indices generally disagree, so the saved SID would not + # match the codebook vector actually reconstructed/trained. STE avoids + # this by looking up embedding(ids) directly. Force a consistent combo. + _is_gumbel = forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX + assert not (use_sinkhorn and _is_gumbel), ( + "use_sinkhorn=True is incompatible with forward_mode=GUMBEL_SOFTMAX: " + "Sinkhorn drives `ids` (balanced assignment) while Gumbel drives " + "`emb` (nearest code), so the returned id and embedding diverge. " + "Use STE with Sinkhorn, or Gumbel-Softmax without Sinkhorn." + ) self.embed_dim = embed_dim self.n_embed = n_embed self.forward_mode = forward_mode diff --git a/tzrec/modules/sid_generation/vector_quantize_test.py b/tzrec/modules/sid_generation/vector_quantize_test.py index 833f36231..1558288f7 100644 --- a/tzrec/modules/sid_generation/vector_quantize_test.py +++ b/tzrec/modules/sid_generation/vector_quantize_test.py @@ -26,7 +26,8 @@ class VectorQuantizeTest(unittest.TestCase): ("ste_l2", QuantizeForwardMode.STE, "l2", True), ("ste_cosine", QuantizeForwardMode.STE, "cosine", True), ("ste_no_sinkhorn", QuantizeForwardMode.STE, "l2", False), - ("gumbel_l2", QuantizeForwardMode.GUMBEL_SOFTMAX, "l2", True), + # Gumbel must run without Sinkhorn (the combo is asserted against). + ("gumbel_l2", QuantizeForwardMode.GUMBEL_SOFTMAX, "l2", False), ] ) def test_train_forward(self, _name, mode, distance_type, use_sinkhorn) -> None: @@ -46,6 +47,16 @@ def test_train_forward(self, _name, mode, distance_type, use_sinkhorn) -> None: self.assertTrue((out.ids >= 0).all() and (out.ids < 16).all()) self.assertTrue(torch.isfinite(out.embeddings).all()) + def test_sinkhorn_gumbel_combo_rejected(self) -> None: + """Sinkhorn + Gumbel would desync `ids` and `emb`; constructor rejects it.""" + with self.assertRaisesRegex(AssertionError, "GUMBEL_SOFTMAX"): + VectorQuantize( + embed_dim=8, + n_embed=16, + forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, + use_sinkhorn=True, + ) + def test_train_forward_backward_reaches_input(self) -> None: torch.manual_seed(0) vq = VectorQuantize(embed_dim=8, n_embed=16, use_sinkhorn=False) From 3667170737c6c6b8d848671667c81b25eb6cea09 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 3 Jun 2026 02:26:20 +0000 Subject: [PATCH 029/129] [refactor] ResidualQuantizer: hoist shared residual walk into the base MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both backends' forward loops did the same thing — per layer: normalize the residual, assign codes, look up the quantized vector, subtract it, accumulate — differing only in (a) how a layer produces (codes, quant) and (b) VQ's per-layer commitment loss. Consolidate the shared structure: - New abstract primitive `_quantize_layer(layer_idx, residual, temperature)` -> (codes, quant), the encode-direction mirror of `_lookup_code`. K-Means runs predict()+centroids (with the uninitialized-layer zero guard moved inside it); VQ runs the VectorQuantize layer and returns the raw, grad-carrying codebook vector. - New concrete `_residual_pass()` in the base drives the walk and returns (cluster_ids, aggregated, cumulative). The residual subtraction uses `quant.detach()` — required for VQ's gradient semantics, a no-op for K-Means (buffer lookup) — so one line serves both. - `get_codes` is now concrete in the base (mirrors decode_codes), so both subclasses drop their copies. VQ no longer routes get_codes through a full training forward (no wasted loss/STE). - Each forward shrinks to: K-Means returns the final sum; VQ adds init_embed_, maps the commitment loss over `cumulative`, and applies STE/rotation. The per-layer commitment loss stays in VQ.forward (mapped over the returned cumulative quants) rather than leaking a loss hook into the shared walk. Behavior is unchanged: the two forwards remain numerically identical and the predict path stays FX-traceable (the loop is now in `_residual_pass`). Tests: add forward-vs-get_codes consistency for both backends; existing VQ backward / FAISS-init / dist-broadcast / K-Means FX-trace / decode tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../residual_kmeans_quantizer.py | 72 +++++++--------- .../sid_generation/residual_quantizer.py | 76 ++++++++++++++-- .../sid_generation/residual_quantizer_test.py | 26 ++++++ .../residual_vector_quantizer.py | 86 +++++++++---------- 4 files changed, 168 insertions(+), 92 deletions(-) diff --git a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py index 0de529a43..2f5a83227 100644 --- a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid_generation/residual_kmeans_quantizer.py @@ -76,12 +76,42 @@ def __init__( ] ) + def _quantize_layer( + self, + layer_idx: int, + residual: torch.Tensor, + temperature: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Nearest-centroid assignment for one layer. + + Uninitialized layers (before ``train_offline``) return zeros, so the + residual walk is a no-op and the model stays callable. ``temperature`` + is unused (no soft assignment). + + Args: + layer_idx (int): quantization layer index. + residual (Tensor): current residual, shape (B, D). + temperature (float): unused. + + Returns: + codes (Tensor): cluster indices, shape (B,). + quantized (Tensor): selected centroids, shape (B, D). + """ + layer = self.layers[layer_idx] + if not layer.is_initialized: + codes = torch.zeros( + residual.shape[0], dtype=torch.long, device=residual.device + ) + return codes, torch.zeros_like(residual) + codes = layer.predict(residual) + return codes, layer.centroids[codes] + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Assign codes per layer and sum the centroids. Codebook is read-only here; training happens in ``train_offline``. - Uninitialized layers return dummy zeros so the model is callable - before the one-shot FAISS fit completes. + Uninitialized layers contribute zeros (see :meth:`_quantize_layer`) so + the model is callable before the one-shot FAISS fit completes. Args: input (Tensor): input embeddings, shape (B, D). @@ -90,45 +120,9 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: codes (Tensor): cluster indices per layer, shape (B, n_layers). quantized (Tensor): sum of quantized embeddings, shape (B, D). """ - residual = input - all_codes: List[torch.Tensor] = [] - quantized_sum = torch.zeros_like(input) - - for layer in self.layers: - if self.normalize_residuals: - residual = F.normalize(residual, dim=-1) - - if layer.is_initialized: - codes = layer.predict(residual) - quantized = layer.centroids[codes] - residual = residual - quantized - quantized_sum = quantized_sum + quantized - else: - codes = torch.zeros( - input.shape[0], dtype=torch.long, device=input.device - ) - all_codes.append(codes) - - cluster_ids = torch.stack(all_codes, dim=-1) # (B, n_layers) + cluster_ids, quantized_sum, _ = self._residual_pass(input) return cluster_ids, quantized_sum - @torch.no_grad() - def get_codes(self, input: torch.Tensor) -> torch.Tensor: - """Assign semantic IDs without updating centroids.""" - residual = input - all_codes: List[torch.Tensor] = [] - - for layer in self.layers: - if self.normalize_residuals: - residual = F.normalize(residual, dim=-1) - - codes = layer.predict(residual) - all_codes.append(codes) - quantized = layer.centroids[codes] - residual = residual - quantized - - return torch.stack(all_codes, dim=-1) - @torch.no_grad() def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """Get centroid weights for a specific layer. diff --git a/tzrec/modules/sid_generation/residual_quantizer.py b/tzrec/modules/sid_generation/residual_quantizer.py index 958514b17..ede8f0444 100644 --- a/tzrec/modules/sid_generation/residual_quantizer.py +++ b/tzrec/modules/sid_generation/residual_quantizer.py @@ -11,10 +11,11 @@ """ResidualQuantizer: abstract base for multi-layer residual quantizers.""" -from typing import List, Union +from typing import List, Tuple, Union import torch from torch import nn +from torch.nn import functional as F def normalize_n_embed(n_embed: Union[int, List[int]], n_layers: int) -> List[int]: @@ -54,10 +55,12 @@ class ResidualQuantizer(nn.Module): Semantic ID = (code_0, code_1, ..., code_{n_layers-1}). This base owns the structural invariants (``embed_dim``, ``n_layers``, - per-layer codebook sizes, residual normalization toggle) and the - backend-agnostic :meth:`decode_codes` / :meth:`output_dim`. Subclasses - build ``self.layers`` and implement :meth:`forward`, :meth:`get_codes`, - :meth:`get_codebook_embeddings`, and :meth:`_lookup_code`. + per-layer codebook sizes, residual normalization toggle) and the shared + residual walk (:meth:`_residual_pass`, :meth:`get_codes`, + :meth:`decode_codes`, :meth:`output_dim`). Subclasses build ``self.layers`` + and implement the per-layer primitives :meth:`_quantize_layer` (encode) and + :meth:`_lookup_code` (decode), plus :meth:`forward` and + :meth:`get_codebook_embeddings`. Args: embed_dim (int): feature / codebook dimension. @@ -90,17 +93,78 @@ def forward(self, input: torch.Tensor): # noqa: ANN201 """Assign codes per layer and accumulate the quantized output.""" raise NotImplementedError + def _quantize_layer( + self, + layer_idx: int, + residual: torch.Tensor, + temperature: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Assign one layer's codes and look up its quantized vector. + + Backend primitive behind the residual walk (encode-direction mirror of + :meth:`_lookup_code`). ``temperature`` is used only by the VQ backend. + + Args: + layer_idx (int): quantization layer index. + residual (Tensor): current residual, shape (B, D). + temperature (float): Gumbel-Softmax temperature (VQ only). + + Returns: + codes (Tensor): per-layer cluster ids, shape (B,). + quantized (Tensor): the layer's quantized vector, shape (B, D). + """ + raise NotImplementedError + + def _residual_pass( + self, + input: torch.Tensor, + temperature: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """Shared residual walk: per-layer assign, subtract, accumulate. + + The quantized vector is subtracted detached (keeps the residual chain + gradient-free) and accumulated (keeps gradient when the backend + supplies it, e.g. VQ). + + Args: + input (Tensor): input embeddings, shape (B, D). + temperature (float): forwarded to :meth:`_quantize_layer`. + + Returns: + cluster_ids (Tensor): stacked codes, shape (B, n_layers). + aggregated (Tensor): sum of quantized vectors, shape (B, D). + cumulative (List[Tensor]): running sum after each layer + (``cumulative[-1] is aggregated``). + """ + residual = input + all_codes: List[torch.Tensor] = [] + cumulative: List[torch.Tensor] = [] + aggregated = torch.zeros_like(input) + for i in range(self.n_layers): + if self.normalize_residuals: + residual = F.normalize(residual, dim=-1) + codes, quantized = self._quantize_layer(i, residual, temperature) + all_codes.append(codes) + aggregated = aggregated + quantized + cumulative.append(aggregated) + residual = residual - quantized.detach() + cluster_ids = torch.stack(all_codes, dim=-1) # (B, n_layers) + return cluster_ids, aggregated, cumulative + @torch.no_grad() def get_codes(self, input: torch.Tensor) -> torch.Tensor: """Assign semantic IDs without updating the codebook. + Shared encode-direction mirror of :meth:`decode_codes`. + Args: input (Tensor): input embeddings, shape (B, D). Returns: Tensor: cluster ids, shape (B, n_layers). """ - raise NotImplementedError + cluster_ids, _, _ = self._residual_pass(input) + return cluster_ids @torch.no_grad() def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py index e26a7033b..c4d4bfe72 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -88,6 +88,15 @@ def test_get_codes_no_grad(self) -> None: codes = self.rvq.get_codes(torch.randn(4, 8)) self.assertEqual(codes.shape, (4, 3)) + def test_forward_get_codes_consistent_eval(self) -> None: + """get_codes (shared base walk) matches forward's ids in eval.""" + self.rvq.eval() + x = torch.randn(6, 8) + fwd_ids = self.rvq(x).cluster_ids + gc_ids = self.rvq.get_codes(x) + self.assertFalse(gc_ids.requires_grad) + torch.testing.assert_close(gc_ids, fwd_ids) + def test_faiss_kmeans_init_seeds_codebook(self) -> None: try: import faiss # noqa: F401 @@ -180,6 +189,23 @@ def test_train_offline_then_decode(self) -> None: recon = rkq.decode_codes(codes) # inherited from the base self.assertEqual(recon.shape, (5, 4)) + def test_forward_get_codes_consistent(self) -> None: + """Forward ids and get_codes both route through the shared walk.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + x = torch.randn(9, 4) + fwd_ids, fwd_quant = rkq(x) + torch.testing.assert_close(rkq.get_codes(x), fwd_ids) + # forward's residual-sum equals the centroid-sum reconstruction. + torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) + if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid_generation/residual_vector_quantizer.py b/tzrec/modules/sid_generation/residual_vector_quantizer.py index a534550e2..956e343bd 100644 --- a/tzrec/modules/sid_generation/residual_vector_quantizer.py +++ b/tzrec/modules/sid_generation/residual_vector_quantizer.py @@ -11,7 +11,7 @@ """ResidualVectorQuantizer: multi-layer residual VQ with gradient training.""" -from typing import List, Sequence, Union +from typing import List, Sequence, Tuple, Union import torch import torch.distributed as dist @@ -276,6 +276,30 @@ def _apply_rotation_trick( x_unsq - 2 * sum_projection + 2 * rescaled_embeddings ).squeeze(1) + def _quantize_layer( + self, + layer_idx: int, + residual: torch.Tensor, + temperature: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize one layer's residual via its ``VectorQuantize`` layer. + + Returns the raw (un-STE'd) codebook vector so gradient still flows into + the codebook; STE is applied once on the aggregate in :meth:`forward`. + + Args: + layer_idx (int): quantization layer index. + residual (Tensor): current residual, shape (B, D). + temperature (float): Gumbel-Softmax temperature. + + Returns: + ids (Tensor): per-layer cluster ids, shape (B,). + raw_emb (Tensor): raw codebook vectors (with grad), shape (B, D). + """ + layer = self.layers[layer_idx] + out = layer(residual, temperature=temperature) + return out.ids, layer.embedding(out.ids) + def forward( self, input: torch.Tensor, @@ -285,11 +309,11 @@ def forward( Training flow: 1. If kmeans_init and not initialized -> init_embed_(input) - 2. For each layer: quantize detached residual, accumulate - into aggregated_quants and compute per-layer commitment loss - in-place (avoids storing a quant_list of clones). - 3. Mean of per-layer commitment losses (cos/l2 with latent_weight) - 4. STE gradient pass-through (or rotation trick) + 2. Shared residual walk (:meth:`_residual_pass`) over the detached + input: per-layer assign + grad-carrying accumulation. + 3. Mean of per-layer commitment losses over the cumulative quants + (cos/l2 with latent_weight). + 4. STE gradient pass-through (or rotation trick). Args: input (Tensor): input embeddings, shape (B, D). @@ -303,36 +327,17 @@ def forward( if self.training: self.init_embed_(input) - # Detach residual for VQ assignment (gradient flows via STE only). - residual = input.detach() - all_ids: List[torch.Tensor] = [] - commitment_loss_list: List[torch.Tensor] = [] - aggregated_quants = torch.zeros_like(input) - - # Step 2: per-layer residual quantization - for layer in self.layers: - if self.normalize_residuals: - residual = F.normalize(residual, dim=-1) - - quantized = layer(residual, temperature=temperature) - all_ids.append(quantized.ids) - - # Separate raw lookup: ``quantized.embeddings`` already applies - # STE (gradient -> encoder), but the commitment loss + residual - # update need the un-STE'd codebook vector with gradient still - # flowing into ``layer.embedding.weight``. - raw_emb = layer.embedding(quantized.ids) - residual = residual - raw_emb.detach() - aggregated_quants = aggregated_quants + raw_emb - - commitment_loss_list.append( - self._single_commitment_loss(input, aggregated_quants) - ) - - cluster_ids = torch.stack(all_ids, dim=-1) # (B, n_layers) + # Step 2: shared residual walk on the detached input (encoder grad + # flows only via the STE in step 4; the accumulated quants keep grad + # so the codebook still trains). cumulative[i] = sum after layer i. + cluster_ids, aggregated_quants, cumulative = self._residual_pass( + input.detach(), temperature + ) # Step 3: aggregate per-layer commitment loss - commitment_loss = torch.mean(torch.stack(commitment_loss_list)) + commitment_loss = torch.mean( + torch.stack([self._single_commitment_loss(input, c) for c in cumulative]) + ) # Step 4: STE or rotation trick (quants_trunc = final accumulated) quants_trunc = aggregated_quants @@ -348,19 +353,6 @@ def forward( quantization_loss=commitment_loss, ) - @torch.no_grad() - def get_codes(self, input: torch.Tensor) -> torch.Tensor: - """Assign semantic IDs without gradient computation. - - Args: - input (Tensor): input embeddings, shape (B, D). - - Returns: - Tensor: cluster ids, shape (B, n_layers). - """ - output = self.forward(input) - return output.cluster_ids - @torch.no_grad() def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """Get codebook embedding weights for a specific layer. From 4b6e9b007eba203b387816580e90e4d220794ad6 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 3 Jun 2026 06:57:44 +0000 Subject: [PATCH 030/129] [review] SID base: address github-actions review on PR #538 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - decode_codes: seed the accumulator from the first lookup so device AND dtype follow the codebook, instead of pinning to fp32 (silently upcasting each layer's add under mixed precision). n_layers >= 1 is guaranteed. - BaseSidModel._update_unique_sid_ratio: guard B == 0 (empty final shard under DDP/TorchRec) to avoid ZeroDivisionError. - residual_quantizer_test: add a fake one-primitive subclass exercising the concrete residual walk the base owns — get_codes shape + aggregate == Σ quantized_i, the detach invariant (codebook grad flows, input gets none), the normalize_residuals branch, and decode_codes sum + codebook dtype. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 2 + .../sid_generation/residual_quantizer.py | 12 ++- .../sid_generation/residual_quantizer_test.py | 81 +++++++++++++++++++ 3 files changed, 88 insertions(+), 7 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 35ef27040..2472753ec 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -122,5 +122,7 @@ def _update_unique_sid_ratio(self, codes: torch.Tensor) -> None: codes (Tensor): semantic-ID codes, shape (B, n_layers). """ B = codes.shape[0] + if B == 0: # empty final shard under DDP/TorchRec + return unique_sids = torch.unique(codes, dim=0).shape[0] self._metric_modules["unique_sid_ratio"].update(unique_sids / B) diff --git a/tzrec/modules/sid_generation/residual_quantizer.py b/tzrec/modules/sid_generation/residual_quantizer.py index ede8f0444..0f0bdd5d5 100644 --- a/tzrec/modules/sid_generation/residual_quantizer.py +++ b/tzrec/modules/sid_generation/residual_quantizer.py @@ -203,12 +203,10 @@ def decode_codes(self, codes: torch.Tensor) -> torch.Tensor: Returns: Tensor: reconstructed embeddings, shape (B, embed_dim). """ - quantized_sum = torch.zeros( - codes.shape[0], - self.embed_dim, - device=codes.device, - dtype=torch.float, - ) - for i in range(self.n_layers): + # Seed from the first lookup so device and dtype follow the codebook + # (avoids pinning the sum to fp32 under mixed precision). n_layers >= 1 + # is guaranteed by the codebook config. + quantized_sum = self._lookup_code(0, codes[:, 0]) + for i in range(1, self.n_layers): quantized_sum = quantized_sum + self._lookup_code(i, codes[:, i]) return quantized_sum diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py index c4d4bfe72..021e396ba 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -12,6 +12,7 @@ import unittest import torch +from torch import nn from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, @@ -61,6 +62,86 @@ def test_abstract_primitives_raise(self) -> None: rq.decode_codes(torch.zeros(3, 2, dtype=torch.long)) +class _FakeQuantizer(ResidualQuantizer): + """Minimal concrete subclass to exercise the base residual walk. + + Implements only the two per-layer primitives over a learnable codebook, + so the base's _residual_pass / get_codes / decode_codes can be tested + without pulling in the K-Means or VQ backends. + """ + + def __init__(self, embed_dim, n_layers, n_embed=5, normalize_residuals=False): + super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) + self.books = nn.ParameterList( + [ + nn.Parameter(torch.randn(self.n_embed_list[i], embed_dim)) + for i in range(n_layers) + ] + ) + + def _quantize_layer(self, layer_idx, residual, temperature=1.0): + codes = (residual.detach() @ self.books[layer_idx].t()).argmax(dim=-1) + return codes, self.books[layer_idx][codes] + + def _lookup_code(self, layer_idx, code_idx): + return self.books[layer_idx][code_idx] + + def forward(self, input): + return self._residual_pass(input) + + def get_codebook_embeddings(self, layer_idx): + return self.books[layer_idx] + + +class ResidualQuantizerWalkTest(unittest.TestCase): + """Exercise the concrete residual walk the base owns (via a fake backend).""" + + def test_residual_pass_shapes_and_aggregate(self) -> None: + torch.manual_seed(0) + fq = _FakeQuantizer(embed_dim=4, n_layers=3, n_embed=5) + x = torch.randn(6, 4) + ids, agg, cum = fq._residual_pass(x) + self.assertEqual(ids.shape, (6, 3)) + self.assertEqual(fq.get_codes(x).shape, (6, 3)) + manual = sum(fq._lookup_code(i, ids[:, i]) for i in range(3)) + torch.testing.assert_close(agg, manual) # aggregated == Σ quantized_i + self.assertTrue(torch.equal(cum[-1], agg)) + + def test_detach_invariant(self) -> None: + torch.manual_seed(0) + fq = _FakeQuantizer(embed_dim=4, n_layers=2, n_embed=5) + x = torch.randn(5, 4, requires_grad=True) + _, agg, _ = fq._residual_pass(x) + # Codebook grad flows, but the residual chain is detached, so the + # input receives no gradient. + self.assertTrue(agg.requires_grad) + agg.sum().backward() + self.assertIsNotNone(fq.books[0].grad) + self.assertIsNone(x.grad) + + def test_normalize_residuals_branch(self) -> None: + torch.manual_seed(0) + fq = _FakeQuantizer( + embed_dim=4, n_layers=2, n_embed=5, normalize_residuals=True + ) + ids, agg, _ = fq._residual_pass(torch.randn(5, 4)) + self.assertEqual(ids.shape, (5, 2)) + self.assertEqual(agg.shape, (5, 4)) + + def test_decode_codes_sum_and_dtype(self) -> None: + torch.manual_seed(0) + fq = _FakeQuantizer(embed_dim=4, n_layers=3, n_embed=5) + codes = torch.randint(0, 5, (6, 3)) + recon = fq.decode_codes(codes) + self.assertEqual(recon.shape, (6, 4)) + manual = sum(fq.books[i][codes[:, i]] for i in range(3)) + torch.testing.assert_close(recon, manual) + # device/dtype follow the codebook (regression for the fp32-pin fix). + fq16 = _FakeQuantizer(embed_dim=4, n_layers=2, n_embed=5).to(torch.bfloat16) + recon16 = fq16.decode_codes(torch.randint(0, 5, (3, 2))) + self.assertEqual(recon16.dtype, torch.bfloat16) + + class ResidualVectorQuantizerTest(unittest.TestCase): def setUp(self) -> None: torch.manual_seed(0) From f568de0dd0ee953ab828755bc3bd91c291788136 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 3 Jun 2026 09:19:04 +0000 Subject: [PATCH 031/129] [review] SID base: second round of github-actions feedback on PR #538 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ResidualQuantizer.__init__: assert n_layers >= 1, making the boundary self-guarding (matches the "n_layers >= 1 is guaranteed" comment in decode_codes; a future n_layers=0 subclass would otherwise index OOB). - BaseSidModel.update_train_metric: type predictions as Dict[str, Tensor] instead of bare dict, matching BaseModel / sibling models. - sid_model_test.py (new): unit-test _update_unique_sid_ratio via __new__ — empty-batch no-op (guard) and the exact 0.75 ratio on a known-duplicate batch. - residual_quantizer_test: make the normalize_residuals test behavioral — toggle the flag on the same input/codebook and assert the assignments differ, so the normalization branch is actually validated (the old shape assertions passed even with the branch deleted). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 4 +- tzrec/models/sid_model_test.py | 53 +++++++++++++++++++ .../sid_generation/residual_quantizer.py | 1 + .../sid_generation/residual_quantizer_test.py | 21 +++++--- 4 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 tzrec/models/sid_model_test.py diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 2472753ec..b25c9d81b 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -11,7 +11,7 @@ """BaseSidModel: shared base for semantic-ID generation models.""" -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional import torch import torchmetrics @@ -105,7 +105,7 @@ def init_metric(self) -> None: def update_train_metric( self, - predictions: dict, + predictions: Dict[str, torch.Tensor], batch: Batch, ) -> None: """Update train-path metric state. diff --git a/tzrec/models/sid_model_test.py b/tzrec/models/sid_model_test.py new file mode 100644 index 000000000..da7a1faee --- /dev/null +++ b/tzrec/models/sid_model_test.py @@ -0,0 +1,53 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +import torchmetrics + +from tzrec.models.sid_model import BaseSidModel + + +class UpdateUniqueSidRatioTest(unittest.TestCase): + """Unit-test the codebook-coverage metric helper. + + ``_update_unique_sid_ratio`` is pure tensor logic over + ``self._metric_modules`` (no proto dependency), so it is testable without + a full ``BaseSidModel`` config — full model coverage waits on the concrete + subclasses in the follow-up PRs. + """ + + def _bare_model(self) -> BaseSidModel: + # Bypass __init__ (which needs a pipeline config); only the metric + # module the helper touches needs to exist. + model = BaseSidModel.__new__(BaseSidModel) + model._metric_modules = {"unique_sid_ratio": torchmetrics.MeanMetric()} + return model + + def test_empty_batch_is_noop(self) -> None: + model = self._bare_model() + model._update_unique_sid_ratio(torch.empty(0, 3, dtype=torch.long)) + # The B == 0 guard returns early -> no sample recorded. + self.assertEqual(model._metric_modules["unique_sid_ratio"].weight.item(), 0.0) + + def test_ratio_on_known_duplicates(self) -> None: + model = self._bare_model() + # 3 unique rows out of 4 -> ratio 0.75. + codes = torch.tensor([[1, 2], [1, 2], [3, 4], [5, 6]]) + model._update_unique_sid_ratio(codes) + self.assertAlmostEqual( + model._metric_modules["unique_sid_ratio"].compute().item(), 0.75, places=6 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid_generation/residual_quantizer.py b/tzrec/modules/sid_generation/residual_quantizer.py index 0f0bdd5d5..6b80f2e33 100644 --- a/tzrec/modules/sid_generation/residual_quantizer.py +++ b/tzrec/modules/sid_generation/residual_quantizer.py @@ -78,6 +78,7 @@ def __init__( normalize_residuals: bool = False, ) -> None: super().__init__() + assert n_layers >= 1, f"n_layers must be >= 1, got {n_layers}" self.embed_dim = embed_dim self.n_layers = n_layers self.normalize_residuals = normalize_residuals diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid_generation/residual_quantizer_test.py index 021e396ba..fe0d2d189 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid_generation/residual_quantizer_test.py @@ -119,14 +119,21 @@ def test_detach_invariant(self) -> None: self.assertIsNotNone(fq.books[0].grad) self.assertIsNone(x.grad) - def test_normalize_residuals_branch(self) -> None: - torch.manual_seed(0) - fq = _FakeQuantizer( - embed_dim=4, n_layers=2, n_embed=5, normalize_residuals=True + def test_normalize_residuals_changes_assignment(self) -> None: + # Same input and same codebook (re-seeded before each build), so the + # only difference is the normalize_residuals branch — it must change + # the residual a later layer sees and hence the codes it assigns. + x = torch.randn(8, 4) + torch.manual_seed(1) + fq_off = _FakeQuantizer(embed_dim=4, n_layers=2, n_embed=6) + torch.manual_seed(1) + fq_on = _FakeQuantizer( + embed_dim=4, n_layers=2, n_embed=6, normalize_residuals=True ) - ids, agg, _ = fq._residual_pass(torch.randn(5, 4)) - self.assertEqual(ids.shape, (5, 2)) - self.assertEqual(agg.shape, (5, 4)) + ids_off, _, _ = fq_off._residual_pass(x) + ids_on, _, _ = fq_on._residual_pass(x) + self.assertEqual(ids_on.shape, (8, 2)) + self.assertFalse(torch.equal(ids_off, ids_on)) def test_decode_codes_sum_and_dtype(self) -> None: torch.manual_seed(0) From 9dc276da15b7ad202bb30537fdc7b65cde8065d3 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 3 Jun 2026 09:35:19 +0000 Subject: [PATCH 032/129] [review] SID: address tiankongdeguiji review on PR #538 - Metrics: switch the shared reconstruction metric to torchmetrics.MeanSquaredError (correct (preds, target) aggregation vs the biased mean-of-batch-means a MeanMetric gave), and add a UniqueRatio metric class (tzrec/metrics/unique_ratio.py) for codebook coverage instead of a MeanMetric fed a hand-computed ratio. The empty-batch guard + DDP reduction live in the metric; callers update it directly (the one-line _update_unique_sid_ratio passthrough is removed). The RQ-VAE train-path mse moves to MeanSquaredError too. - __init__.py: revert tzrec/modules/sid_generation/__init__.py to a bare package marker (no re-exports) and import ResidualKMeansQuantizer from its submodule in sid_rqkmeans, per "avoid adding to __init__.py". Tests: tzrec/metrics/unique_ratio_test.py (ratio, empty no-op, mean over batches). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/unique_ratio.py | 48 +++++++++++++++++++++ tzrec/metrics/unique_ratio_test.py | 40 ++++++++++++++++++ tzrec/models/sid_model.py | 17 ++------ tzrec/models/sid_model_test.py | 53 ------------------------ tzrec/models/sid_rqkmeans.py | 14 +++++-- tzrec/models/sid_rqvae.py | 10 ++--- tzrec/modules/sid_generation/__init__.py | 32 -------------- 7 files changed, 105 insertions(+), 109 deletions(-) create mode 100644 tzrec/metrics/unique_ratio.py create mode 100644 tzrec/metrics/unique_ratio_test.py delete mode 100644 tzrec/models/sid_model_test.py diff --git a/tzrec/metrics/unique_ratio.py b/tzrec/metrics/unique_ratio.py new file mode 100644 index 000000000..63d7a4126 --- /dev/null +++ b/tzrec/metrics/unique_ratio.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics import Metric + + +class UniqueRatio(Metric): + """Codebook-coverage metric: mean of per-batch (unique rows / batch size). + + Each ``update`` counts the unique rows of a ``(B, n_layers)`` semantic-ID + code tensor and accumulates the per-batch ratio; ``compute`` returns the + running mean. Empty batches (``B == 0``, e.g. an empty final DDP/TorchRec + shard) are skipped. States reduce by ``sum`` across ranks. + """ + + higher_is_better = True + is_differentiable = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.add_state("ratio_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, codes: torch.Tensor) -> None: + """Accumulate the unique-ratio of one batch of codes. + + Args: + codes (Tensor): semantic-ID codes, shape (B, n_layers). + """ + batch_size = codes.shape[0] + if batch_size == 0: + return + unique = torch.unique(codes, dim=0).shape[0] + self.ratio_sum += unique / batch_size + self.count += 1 + + def compute(self) -> torch.Tensor: + """Mean per-batch unique ratio (NaN before any non-empty update).""" + return self.ratio_sum / self.count diff --git a/tzrec/metrics/unique_ratio_test.py b/tzrec/metrics/unique_ratio_test.py new file mode 100644 index 000000000..83e89ecd6 --- /dev/null +++ b/tzrec/metrics/unique_ratio_test.py @@ -0,0 +1,40 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.metrics.unique_ratio import UniqueRatio + + +class UniqueRatioTest(unittest.TestCase): + def test_known_duplicates(self) -> None: + metric = UniqueRatio() + # 3 unique rows out of 4 -> 0.75. + metric.update(torch.tensor([[1, 2], [1, 2], [3, 4], [5, 6]])) + self.assertAlmostEqual(metric.compute().item(), 0.75, places=6) + + def test_empty_batch_skipped(self) -> None: + metric = UniqueRatio() + metric.update(torch.empty(0, 3, dtype=torch.long)) + self.assertEqual(metric.count.item(), 0.0) + self.assertTrue(torch.isnan(metric.compute())) + + def test_mean_over_batches(self) -> None: + metric = UniqueRatio() + metric.update(torch.tensor([[1, 1], [1, 1]])) # 1/2 = 0.5 + metric.update(torch.tensor([[1, 1], [2, 2]])) # 2/2 = 1.0 + self.assertAlmostEqual(metric.compute().item(), 0.75, places=6) # mean + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index b25c9d81b..394616f87 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -18,6 +18,7 @@ from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature +from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel from tzrec.protos.model_pb2 import ModelConfig @@ -100,8 +101,8 @@ def init_metric(self) -> None: ``unique_sid_ratio``: codebook coverage = unique SIDs / batch size. Subclasses call ``super().init_metric()`` then add their extras. """ - self._metric_modules["mse"] = torchmetrics.MeanMetric() - self._metric_modules["unique_sid_ratio"] = torchmetrics.MeanMetric() + self._metric_modules["mse"] = torchmetrics.MeanSquaredError() + self._metric_modules["unique_sid_ratio"] = UniqueRatio() def update_train_metric( self, @@ -114,15 +115,3 @@ def update_train_metric( with a meaningful train signal (RQ-VAE) override this. """ return - - def _update_unique_sid_ratio(self, codes: torch.Tensor) -> None: - """Update the codebook-coverage metric (unique SIDs / batch size). - - Args: - codes (Tensor): semantic-ID codes, shape (B, n_layers). - """ - B = codes.shape[0] - if B == 0: # empty final shard under DDP/TorchRec - return - unique_sids = torch.unique(codes, dim=0).shape[0] - self._metric_modules["unique_sid_ratio"].update(unique_sids / B) diff --git a/tzrec/models/sid_model_test.py b/tzrec/models/sid_model_test.py deleted file mode 100644 index da7a1faee..000000000 --- a/tzrec/models/sid_model_test.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch -import torchmetrics - -from tzrec.models.sid_model import BaseSidModel - - -class UpdateUniqueSidRatioTest(unittest.TestCase): - """Unit-test the codebook-coverage metric helper. - - ``_update_unique_sid_ratio`` is pure tensor logic over - ``self._metric_modules`` (no proto dependency), so it is testable without - a full ``BaseSidModel`` config — full model coverage waits on the concrete - subclasses in the follow-up PRs. - """ - - def _bare_model(self) -> BaseSidModel: - # Bypass __init__ (which needs a pipeline config); only the metric - # module the helper touches needs to exist. - model = BaseSidModel.__new__(BaseSidModel) - model._metric_modules = {"unique_sid_ratio": torchmetrics.MeanMetric()} - return model - - def test_empty_batch_is_noop(self) -> None: - model = self._bare_model() - model._update_unique_sid_ratio(torch.empty(0, 3, dtype=torch.long)) - # The B == 0 guard returns early -> no sample recorded. - self.assertEqual(model._metric_modules["unique_sid_ratio"].weight.item(), 0.0) - - def test_ratio_on_known_duplicates(self) -> None: - model = self._bare_model() - # 3 unique rows out of 4 -> ratio 0.75. - codes = torch.tensor([[1, 2], [1, 2], [3, 4], [5, 6]]) - model._update_unique_sid_ratio(codes) - self.assertAlmostEqual( - model._metric_modules["unique_sid_ratio"].compute().item(), 0.75, places=6 - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index b1a1b07f1..390226053 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -27,8 +27,10 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid_generation import ResidualKMeansQuantizer from tzrec.modules.sid_generation.kmeans import recon_diagnostics +from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) from tzrec.protos.model_pb2 import ModelConfig from tzrec.utils import config_util from tzrec.utils.logging_util import logger @@ -255,14 +257,18 @@ def update_metric( losses (dict, optional): a dict of loss. """ if "input_embedding" in predictions: - mse, rel = recon_diagnostics( + _, rel = recon_diagnostics( predictions["input_embedding"], predictions["quantized"], ) - self._metric_modules["mse"].update(mse) + # MeanSquaredError aggregates (preds, target) itself; rel_loss has + # no torchmetrics equivalent so it stays a MeanMetric. + self._metric_modules["mse"].update( + predictions["quantized"], predictions["input_embedding"] + ) self._metric_modules["rel_loss"].update(rel) - self._update_unique_sid_ratio(predictions["codes"]) + self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) @torch.no_grad() def on_train_end(self) -> None: diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 94531d31c..3ff92e450 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -327,7 +327,7 @@ def init_metric(self) -> None: # is intentionally eval-only: torch.unique(codes, dim=0).shape[0] # forces a GPU->host sync every step, and codebook coverage is a # diagnostic, not a training signal. - self._train_metric_modules["mse"] = torchmetrics.MeanMetric() + self._train_metric_modules["mse"] = torchmetrics.MeanSquaredError() def update_train_metric( self, @@ -342,8 +342,7 @@ def update_train_metric( """ if "x_hat" in predictions: embedding = self._extract_feature(batch) - mse = F.mse_loss(predictions["x_hat"], embedding, reduction="mean") - self._train_metric_modules["mse"].update(mse) + self._train_metric_modules["mse"].update(predictions["x_hat"], embedding) def update_metric( self, @@ -360,7 +359,6 @@ def update_metric( """ if "x_hat" in predictions: embedding = self._extract_feature(batch) - mse = F.mse_loss(predictions["x_hat"], embedding, reduction="mean") - self._metric_modules["mse"].update(mse) + self._metric_modules["mse"].update(predictions["x_hat"], embedding) - self._update_unique_sid_ratio(predictions["codes"]) + self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid_generation/__init__.py index d6c3e9350..eedc773bc 100644 --- a/tzrec/modules/sid_generation/__init__.py +++ b/tzrec/modules/sid_generation/__init__.py @@ -8,35 +8,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from tzrec.modules.sid_generation.kmeans import ( - KMeansLayer, -) -from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( - ResidualKMeansQuantizer, -) -from tzrec.modules.sid_generation.residual_quantizer import ( - ResidualQuantizer, -) -from tzrec.modules.sid_generation.residual_vector_quantizer import ( - ResidualVectorQuantizer, -) -from tzrec.modules.sid_generation.types import ( - QuantizeForwardMode, - QuantizeOutput, - ResidualQuantizerOutput, -) -from tzrec.modules.sid_generation.vector_quantize import ( - VectorQuantize, -) - -__all__ = [ - "QuantizeForwardMode", - "QuantizeOutput", - "ResidualQuantizerOutput", - "VectorQuantize", - "ResidualQuantizer", - "ResidualVectorQuantizer", - "KMeansLayer", - "ResidualKMeansQuantizer", -] From 98019319d56df49b311501a054c9052c58f05d51 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 3 Jun 2026 10:42:37 +0000 Subject: [PATCH 033/129] [review] SID: honest unique_sid_ratio framing + update_metric docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses github-actions review on PR #538 (no behavior change): - UniqueRatio stays the cheap per-batch metric (mean of distinct-rows / batch-size, two scalar states) — deliberately kept per-batch, not global distinct-SID coverage, to avoid accumulating an O(#distinct-SIDs) set during predict. Docstrings now say so plainly (a batch-size-sensitive diversity proxy, not global codebook coverage) instead of overselling it as "codebook coverage". - BaseSidModel docstring: subclasses extend init_metric (via super) and *implement* update_metric (BaseModel.update_metric raises NotImplementedError, so there is nothing to "extend"); update_train_metric defaults to a no-op. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/unique_ratio.py | 16 +++++++++------- tzrec/metrics/unique_ratio_test.py | 17 +++++++++-------- tzrec/models/sid_model.py | 13 ++++++++----- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/tzrec/metrics/unique_ratio.py b/tzrec/metrics/unique_ratio.py index 63d7a4126..8adb5e398 100644 --- a/tzrec/metrics/unique_ratio.py +++ b/tzrec/metrics/unique_ratio.py @@ -14,12 +14,14 @@ class UniqueRatio(Metric): - """Codebook-coverage metric: mean of per-batch (unique rows / batch size). - - Each ``update`` counts the unique rows of a ``(B, n_layers)`` semantic-ID - code tensor and accumulates the per-batch ratio; ``compute`` returns the - running mean. Empty batches (``B == 0``, e.g. an empty final DDP/TorchRec - shard) are skipped. States reduce by ``sum`` across ranks. + """Mean per-batch unique-SID ratio (distinct rows / batch size). + + Averages, over batches, the fraction of distinct semantic-ID rows in each + batch. It is a cheap (two-scalar state) **diversity proxy**, NOT global + codebook coverage: a SID repeated across different batches counts as + distinct in each, and smaller batches bias the value toward 1.0. Empty + batches are skipped; the per-rank sums reduce by ``sum`` (a count-weighted + mean). """ higher_is_better = True @@ -31,7 +33,7 @@ def __init__(self, **kwargs) -> None: self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum") def update(self, codes: torch.Tensor) -> None: - """Accumulate the unique-ratio of one batch of codes. + """Accumulate one batch's distinct-row ratio. Args: codes (Tensor): semantic-ID codes, shape (B, n_layers). diff --git a/tzrec/metrics/unique_ratio_test.py b/tzrec/metrics/unique_ratio_test.py index 83e89ecd6..70807ef6d 100644 --- a/tzrec/metrics/unique_ratio_test.py +++ b/tzrec/metrics/unique_ratio_test.py @@ -17,24 +17,25 @@ class UniqueRatioTest(unittest.TestCase): - def test_known_duplicates(self) -> None: + def test_single_batch_ratio(self) -> None: metric = UniqueRatio() - # 3 unique rows out of 4 -> 0.75. + # 3 distinct rows out of 4 -> 0.75. metric.update(torch.tensor([[1, 2], [1, 2], [3, 4], [5, 6]])) self.assertAlmostEqual(metric.compute().item(), 0.75, places=6) + def test_mean_over_batches(self) -> None: + metric = UniqueRatio() + metric.update(torch.tensor([[1, 1], [1, 1]])) # 1/2 = 0.5 + metric.update(torch.tensor([[1, 1], [2, 2]])) # 2/2 = 1.0 + # Per-batch mean = 0.75 (a global distinct/total would give 0.5). + self.assertAlmostEqual(metric.compute().item(), 0.75, places=6) + def test_empty_batch_skipped(self) -> None: metric = UniqueRatio() metric.update(torch.empty(0, 3, dtype=torch.long)) self.assertEqual(metric.count.item(), 0.0) self.assertTrue(torch.isnan(metric.compute())) - def test_mean_over_batches(self) -> None: - metric = UniqueRatio() - metric.update(torch.tensor([[1, 1], [1, 1]])) # 1/2 = 0.5 - metric.update(torch.tensor([[1, 1], [2, 2]])) # 2/2 = 1.0 - self.assertAlmostEqual(metric.compute().item(), 0.75, places=6) # mean - if __name__ == "__main__": unittest.main() diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 394616f87..973fcf99f 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -35,12 +35,14 @@ class BaseSidModel(BaseModel): and the per-layer ``codebook`` (``_n_embed_list`` / ``_n_layers``), - reading the item-embedding feature out of ``Batch.dense_features``, - the eval metrics every SID model reports — reconstruction ``mse`` and - ``unique_sid_ratio`` (codebook coverage). + ``unique_sid_ratio`` (mean per-batch unique-SID ratio, a diversity + proxy). Subclasses build their quantizer in ``__init__`` (after calling ``super().__init__``) and implement :meth:`predict` and :meth:`loss`. - They extend :meth:`init_metric` / :meth:`update_metric` with any - backend-specific metrics. + They extend :meth:`init_metric` (via ``super()``) and implement + :meth:`update_metric` to populate the registered metrics + (:meth:`update_train_metric` defaults to a no-op). Args: model_config (ModelConfig): an instance of ModelConfig. @@ -98,8 +100,9 @@ def init_metric(self) -> None: """Initialize the eval metrics shared by all SID models. ``mse``: reconstruction error (input vs. quantized / decoded). - ``unique_sid_ratio``: codebook coverage = unique SIDs / batch size. - Subclasses call ``super().init_metric()`` then add their extras. + ``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows / + batch size; a batch-size-sensitive diversity proxy, not global + coverage). Subclasses call ``super().init_metric()`` then add extras. """ self._metric_modules["mse"] = torchmetrics.MeanSquaredError() self._metric_modules["unique_sid_ratio"] = UniqueRatio() From e97e742c80c01bd12de461861a4d94a4444d7ed4 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Fri, 5 Jun 2026 03:54:52 +0000 Subject: [PATCH 034/129] [refactor] rename tzrec/modules/sid_generation -> tzrec/modules/sid; bump version to 1.2.16 Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 4 ++-- tzrec/models/sid_rqvae.py | 2 +- tzrec/models/sid_rqvae_test.py | 4 ++-- tzrec/modules/{sid_generation => sid}/__init__.py | 0 tzrec/modules/{sid_generation => sid}/kmeans.py | 0 tzrec/modules/{sid_generation => sid}/kmeans_test.py | 2 +- .../{sid_generation => sid}/residual_kmeans_quantizer.py | 4 ++-- .../modules/{sid_generation => sid}/residual_quantizer.py | 0 .../{sid_generation => sid}/residual_quantizer_test.py | 8 ++++---- .../{sid_generation => sid}/residual_vector_quantizer.py | 8 ++++---- .../residual_vector_quantizer_dist_test.py | 2 +- tzrec/modules/{sid_generation => sid}/types.py | 0 tzrec/modules/{sid_generation => sid}/vector_quantize.py | 4 ++-- .../{sid_generation => sid}/vector_quantize_test.py | 4 ++-- tzrec/version.py | 2 +- 15 files changed, 22 insertions(+), 22 deletions(-) rename tzrec/modules/{sid_generation => sid}/__init__.py (100%) rename tzrec/modules/{sid_generation => sid}/kmeans.py (100%) rename tzrec/modules/{sid_generation => sid}/kmeans_test.py (98%) rename tzrec/modules/{sid_generation => sid}/residual_kmeans_quantizer.py (98%) rename tzrec/modules/{sid_generation => sid}/residual_quantizer.py (100%) rename tzrec/modules/{sid_generation => sid}/residual_quantizer_test.py (97%) rename tzrec/modules/{sid_generation => sid}/residual_vector_quantizer.py (98%) rename tzrec/modules/{sid_generation => sid}/residual_vector_quantizer_dist_test.py (98%) rename tzrec/modules/{sid_generation => sid}/types.py (100%) rename tzrec/modules/{sid_generation => sid}/vector_quantize.py (98%) rename tzrec/modules/{sid_generation => sid}/vector_quantize_test.py (95%) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 390226053..b9c3c8800 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -27,8 +27,8 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid_generation.kmeans import recon_diagnostics -from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( +from tzrec.modules.sid.kmeans import recon_diagnostics +from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) from tzrec.protos.model_pb2 import ModelConfig diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 3ff92e450..644e61c74 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -29,7 +29,7 @@ from tzrec.features.feature import BaseFeature from tzrec.loss.clip_loss import MaskedCLIPLoss from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid_generation.residual_vector_quantizer import ( +from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, ) from tzrec.protos.model_pb2 import ModelConfig diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 56542dd33..7ed681dfb 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -310,7 +310,7 @@ def test_commitment_loss_l1_branch(self) -> None: Previously ``"l1"`` silently fell through to the L2 branch. """ - from tzrec.modules.sid_generation.residual_vector_quantizer import ( + from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, ) @@ -374,7 +374,7 @@ def test_sinkhorn_config_default_enabled(self) -> None: def test_commitment_loss_invalid_raises(self) -> None: """ResidualVectorQuantizer rejects unknown commitment_loss spellings.""" - from tzrec.modules.sid_generation.residual_vector_quantizer import ( + from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, ) diff --git a/tzrec/modules/sid_generation/__init__.py b/tzrec/modules/sid/__init__.py similarity index 100% rename from tzrec/modules/sid_generation/__init__.py rename to tzrec/modules/sid/__init__.py diff --git a/tzrec/modules/sid_generation/kmeans.py b/tzrec/modules/sid/kmeans.py similarity index 100% rename from tzrec/modules/sid_generation/kmeans.py rename to tzrec/modules/sid/kmeans.py diff --git a/tzrec/modules/sid_generation/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py similarity index 98% rename from tzrec/modules/sid_generation/kmeans_test.py rename to tzrec/modules/sid/kmeans_test.py index 52d685b0e..8fed1f83a 100644 --- a/tzrec/modules/sid_generation/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -13,7 +13,7 @@ import torch -from tzrec.modules.sid_generation.kmeans import ( +from tzrec.modules.sid.kmeans import ( KMeansLayer, _squared_euclidean_distance, faiss_residual_kmeans, diff --git a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py similarity index 98% rename from tzrec/modules/sid_generation/residual_kmeans_quantizer.py rename to tzrec/modules/sid/residual_kmeans_quantizer.py index 2f5a83227..505a1b1dc 100644 --- a/tzrec/modules/sid_generation/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -22,8 +22,8 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid_generation.kmeans import KMeansLayer, recon_diagnostics -from tzrec.modules.sid_generation.residual_quantizer import ResidualQuantizer +from tzrec.modules.sid.kmeans import KMeansLayer, recon_diagnostics +from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger diff --git a/tzrec/modules/sid_generation/residual_quantizer.py b/tzrec/modules/sid/residual_quantizer.py similarity index 100% rename from tzrec/modules/sid_generation/residual_quantizer.py rename to tzrec/modules/sid/residual_quantizer.py diff --git a/tzrec/modules/sid_generation/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py similarity index 97% rename from tzrec/modules/sid_generation/residual_quantizer_test.py rename to tzrec/modules/sid/residual_quantizer_test.py index fe0d2d189..346c43c41 100644 --- a/tzrec/modules/sid_generation/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -14,17 +14,17 @@ import torch from torch import nn -from tzrec.modules.sid_generation.residual_kmeans_quantizer import ( +from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) -from tzrec.modules.sid_generation.residual_quantizer import ( +from tzrec.modules.sid.residual_quantizer import ( ResidualQuantizer, normalize_n_embed, ) -from tzrec.modules.sid_generation.residual_vector_quantizer import ( +from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, ) -from tzrec.modules.sid_generation.types import ResidualQuantizerOutput +from tzrec.modules.sid.types import ResidualQuantizerOutput class NormalizeNEmbedTest(unittest.TestCase): diff --git a/tzrec/modules/sid_generation/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py similarity index 98% rename from tzrec/modules/sid_generation/residual_vector_quantizer.py rename to tzrec/modules/sid/residual_vector_quantizer.py index 956e343bd..20f534a36 100644 --- a/tzrec/modules/sid_generation/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -18,13 +18,13 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid_generation.kmeans import faiss_residual_kmeans -from tzrec.modules.sid_generation.residual_quantizer import ResidualQuantizer -from tzrec.modules.sid_generation.types import ( +from tzrec.modules.sid.kmeans import faiss_residual_kmeans +from tzrec.modules.sid.residual_quantizer import ResidualQuantizer +from tzrec.modules.sid.types import ( QuantizeForwardMode, ResidualQuantizerOutput, ) -from tzrec.modules.sid_generation.vector_quantize import VectorQuantize +from tzrec.modules.sid.vector_quantize import VectorQuantize from tzrec.utils.logging_util import logger diff --git a/tzrec/modules/sid_generation/residual_vector_quantizer_dist_test.py b/tzrec/modules/sid/residual_vector_quantizer_dist_test.py similarity index 98% rename from tzrec/modules/sid_generation/residual_vector_quantizer_dist_test.py rename to tzrec/modules/sid/residual_vector_quantizer_dist_test.py index b36e182a4..4065e943d 100644 --- a/tzrec/modules/sid_generation/residual_vector_quantizer_dist_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_dist_test.py @@ -25,7 +25,7 @@ import torch.distributed as dist import torch.multiprocessing as mp -from tzrec.modules.sid_generation.residual_vector_quantizer import ( +from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, ) from tzrec.utils import misc_util diff --git a/tzrec/modules/sid_generation/types.py b/tzrec/modules/sid/types.py similarity index 100% rename from tzrec/modules/sid_generation/types.py rename to tzrec/modules/sid/types.py diff --git a/tzrec/modules/sid_generation/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py similarity index 98% rename from tzrec/modules/sid_generation/vector_quantize.py rename to tzrec/modules/sid/vector_quantize.py index 427cb50fc..16ec0d629 100644 --- a/tzrec/modules/sid_generation/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -18,8 +18,8 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid_generation.kmeans import _squared_euclidean_distance -from tzrec.modules.sid_generation.types import ( +from tzrec.modules.sid.kmeans import _squared_euclidean_distance +from tzrec.modules.sid.types import ( QuantizeForwardMode, QuantizeOutput, ) diff --git a/tzrec/modules/sid_generation/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py similarity index 95% rename from tzrec/modules/sid_generation/vector_quantize_test.py rename to tzrec/modules/sid/vector_quantize_test.py index 1558288f7..4df9208dc 100644 --- a/tzrec/modules/sid_generation/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -14,8 +14,8 @@ import torch from parameterized import parameterized -from tzrec.modules.sid_generation.types import QuantizeForwardMode -from tzrec.modules.sid_generation.vector_quantize import VectorQuantize +from tzrec.modules.sid.types import QuantizeForwardMode +from tzrec.modules.sid.vector_quantize import VectorQuantize class VectorQuantizeTest(unittest.TestCase): diff --git a/tzrec/version.py b/tzrec/version.py index 4ac3e6f3c..8bb46cf62 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.15" +__version__ = "1.2.16" From 995b23e48c26bd92e66403d2199b3206006d7efd Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Fri, 5 Jun 2026 07:10:27 +0000 Subject: [PATCH 035/129] [feat] SID: add SidRqkmeans model (FAISS-trained residual K-Means) Second of three PRs splitting the Semantic-ID models onto the shared base from #538. Adds the concrete RQ-KMeans backend on top of ResidualQuantizer / BaseSidModel; RQ-VAE follows in PR3. - tzrec/modules/sid/kmeans.py: KMeansLayer centroid container + recon_diagnostics. - tzrec/modules/sid/residual_kmeans_quantizer.py: ResidualKMeansQuantizer (FAISS-trained, FX-traceable forward, non-uniform per-layer codebooks). - tzrec/models/sid_rqkmeans.py: SidRqkmeans(BaseSidModel) - gradient -free; reservoir-samples embeddings during the train loop and fits FAISS once in on_train_end. - tzrec/models/model.py: BaseModel.on_train_end() no-op lifecycle hook. - tzrec/main.py: invoke on_train_end after the train loop and force the tail checkpoint so post-hook state is persisted. - protos: SidRqkmeans message + ModelConfig registration (601; 600 is reserved for SidRqvae in PR3). - tests: kmeans_test, ResidualKMeansQuantizerTest, sid_rqkmeans_test. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 9 + tzrec/models/model.py | 9 + tzrec/models/sid_rqkmeans.py | 343 ++++++++++++++++++ tzrec/models/sid_rqkmeans_test.py | 306 ++++++++++++++++ tzrec/modules/sid/kmeans.py | 222 ++++++++++++ tzrec/modules/sid/kmeans_test.py | 100 +++++ .../modules/sid/residual_kmeans_quantizer.py | 248 +++++++++++++ tzrec/modules/sid/residual_quantizer_test.py | 92 +++++ tzrec/protos/model.proto | 5 + tzrec/protos/models/sid_model.proto | 31 ++ 10 files changed, 1365 insertions(+) create mode 100644 tzrec/models/sid_rqkmeans.py create mode 100644 tzrec/models/sid_rqkmeans_test.py create mode 100644 tzrec/modules/sid/kmeans.py create mode 100644 tzrec/modules/sid/kmeans_test.py create mode 100644 tzrec/modules/sid/residual_kmeans_quantizer.py create mode 100644 tzrec/protos/models/sid_model.proto diff --git a/tzrec/main.py b/tzrec/main.py index 87f2984fb..8824e8373 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -500,6 +500,15 @@ def _train_and_evaluate( if lr.by_epoch: lr.step() + # One-shot end-of-loop hook (default no-op). Some models do real work + # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings + # collected during training. Since that mutates model state, force the + # tail-save below to fire so the post-hook state is persisted even when + # the last in-loop checkpoint coincided with the final step. + _model.on_train_end() + if last_ckpt_step == i_step: + last_ckpt_step = -1 + _log_train( i_step, losses, diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 40da5335a..10fa8aae5 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -150,6 +150,15 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: metric_results[metric_name] = metric.compute() return metric_results + def on_train_end(self) -> None: + """Hook fired once after the train_eval loop exits. + + Default: no-op. Override in models that need one-shot end-of-loop + work — e.g. :class:`SidRqkmeans` uses this hook to fit the FAISS + codebook from the embedding sample it collected during training. + """ + pass + def sparse_parameters( self, ) -> Tuple[Iterable[nn.Parameter], Iterable[nn.Parameter]]: diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py new file mode 100644 index 000000000..b9c3c8800 --- /dev/null +++ b/tzrec/models/sid_rqkmeans.py @@ -0,0 +1,343 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SidRqkmeans: SID generation model using residual K-Means. + +Training is FAISS-only: ``predict`` collects embeddings into a CPU +buffer; the actual FAISS fit is triggered ONCE after the train_eval +loop ends, via the :meth:`BaseModel.on_train_end` lifecycle hook +(``tzrec.main`` calls ``_model.on_train_end()`` unconditionally). +""" + +from typing import Any, Dict, List, Optional + +import torch +import torch.distributed as dist +import torchmetrics +from torch import nn + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.sid_model import BaseSidModel +from tzrec.modules.sid.kmeans import recon_diagnostics +from tzrec.modules.sid.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) +from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils import config_util +from tzrec.utils.logging_util import logger + + +def _coerce_proto_numbers(d: Dict) -> Dict: + """Coerce float-typed integers back to int. + + ``google.protobuf.Struct.number_value`` is always float, but most + ``faiss.Kmeans`` kwargs (``niter``, ``seed``, ``nredo``, ...) require + Python ``int``. This helper converts any float that is an exact + integer to ``int`` for downstream consumption. + """ + out: Dict = {} + for k, v in d.items(): + if isinstance(v, float) and v.is_integer(): + out[k] = int(v) + else: + out[k] = v + return out + + +class SidRqkmeans(BaseSidModel): + """SID generation model using residual K-Means (FAISS-only). + + No gradient-based training. The codebook is built once at the end + of the train_eval loop via a single FAISS K-Means pass over the + embeddings collected during training. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. + """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + cfg = self._model_config # SidRqkmeans proto message + + # config_to_kwargs returns Struct numbers as floats (it is + # MessageToDict under the hood), so _coerce_proto_numbers restores + # the ints faiss.Kmeans expects (niter, seed, nredo, ...). + self._faiss_kwargs = ( + _coerce_proto_numbers(config_util.config_to_kwargs(cfg.faiss_kmeans_kwargs)) + if cfg.HasField("faiss_kmeans_kwargs") + else {} + ) + + self._quantizer = ResidualKMeansQuantizer( + embed_dim=self._input_dim, + n_layers=self._n_layers, + n_embed=self._n_embed_list, + normalize_residuals=self._normalize_residuals, + faiss_kmeans_kwargs=self._faiss_kwargs, + ) + + # Per-rank reservoir cap. FAISS K-Means only ever consumes + # K * max_points_per_centroid points (it subsamples internally), so + # buffering the full corpus is wasted memory. We reservoir-sample to + # that target instead, split across ranks so the gathered set on + # rank0 is ~train_sample_size and FAISS does no further subsampling. + # Use the LARGEST per-layer K so non-uniform codebooks (e.g. + # [256, 512, 1024]) still feed their biggest layer enough points. + k = max(self._n_embed_list) + max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) + global_target = ( + cfg.train_sample_size if cfg.train_sample_size > 0 else k * max_ppc + ) + world_size = dist.get_world_size() if dist.is_initialized() else 1 + self._sample_cap = max(1, -(-global_target // world_size)) # ceil div + + # Bounded host-resident reservoir (allocated lazily on first batch, + # once the embedding dim/device is known). ``_n_filled`` slots hold + # data; ``_n_seen`` is the running count for the sampling probability. + self._reservoir: Optional[torch.Tensor] = None + self._n_filled = 0 + self._n_seen = 0 + + # KMeans has no learnable parameters (centroids use register_buffer). + # Add dummy param to keep optimizer/DDP happy. + self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) + + @torch.no_grad() + def _reservoir_add(self, x: torch.Tensor) -> None: + """Add a batch to the bounded reservoir (Vitter's Algorithm R). + + Keeps a uniform random ``self._sample_cap`` subset of every embedding + seen so far in O(cap) host memory, in a single streaming pass. + + Args: + x (Tensor): a batch of embeddings, shape (B, D); copied to host. + """ + x = x.detach().to("cpu", dtype=torch.float32) + cap = self._sample_cap + if self._reservoir is None: + self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) + + # Phase 1: fill empty slots first. + if self._n_filled < cap: + take = min(x.shape[0], cap - self._n_filled) + self._reservoir[self._n_filled : self._n_filled + take] = x[:take] + self._n_filled += take + self._n_seen += take + x = x[take:] + if x.shape[0] == 0: + return + + # Phase 2: replacement. Row j (0-indexed in x) is the + # (n_seen + j)-th item seen; it enters the reservoir with prob + # cap / (n_seen + j + 1), displacing a uniformly-random slot. + r = x.shape[0] + pos = self._n_seen + torch.arange(r) + accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) + idx = accept.nonzero(as_tuple=True)[0] + if idx.numel() > 0: + slots = torch.randint(0, cap, (idx.numel(),)) + # Intra-batch slot collisions resolve last-write-wins; the bias is + # O(B/cap) per step and negligible for codebook fitting. + self._reservoir[slots] = x[idx] + self._n_seen += r + + def _reservoir_sample(self) -> torch.Tensor: + """Return the filled portion of the reservoir, shape (n_filled, D).""" + if self._reservoir is None or self._n_filled == 0: + return torch.empty(0, self._input_dim, dtype=torch.float32) + return self._reservoir[: self._n_filled] + + def _reset_reservoir(self) -> None: + """Drop the reservoir after the FAISS fit to free host memory.""" + self._reservoir = None + self._n_filled = 0 + self._n_seen = 0 + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. + + Training: buffer embeddings only (codes are dummy until FAISS fits). + Eval/inference (after ``on_train_end``): real predict + lookup. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. + """ + embedding = self._extract_feature(batch) + + # Training: reservoir-sample into a bounded host buffer for the + # end-of-loop FAISS fit, and return dummy codes — the codebook does + # not exist yet. The reservoir caps memory at _sample_cap rows + # regardless of corpus size (FAISS only consumes a subset anyway). + if self.is_train: + self._reservoir_add(embedding) + B = embedding.shape[0] + return { + "codes": torch.zeros( + B, self._n_layers, dtype=torch.long, device=embedding.device + ) + } + + codes, quantized = self._quantizer(embedding) + + predictions: Dict[str, torch.Tensor] = { + "codes": codes, + } + + if self.is_eval: + predictions["quantized"] = quantized + predictions["input_embedding"] = embedding + + return predictions + + def loss( + self, predictions: Dict[str, torch.Tensor], batch: Batch + ) -> Dict[str, torch.Tensor]: + """Compute loss of the model. + + Returns zero loss to keep TrainWrapper backward happy. + _dummy_param * 0.0 ensures a compute graph exists so DDP + does not complain about unused parameters. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + + Return: + losses (dict): a dict of loss tensor. + """ + return {"dummy_loss": self._dummy_param.sum() * 0.0} + + def init_metric(self) -> None: + """Initialize metric modules (shared eval metrics + rel_loss). + + Only eval metrics are registered. During training ``predict`` + returns dummy zero codes (the codebook does not exist yet), so + any train-time metric would be either NaN or trivially constant; + the inherited no-op ``update_train_metric`` keeps the train path + empty (``compute_train_metric`` then returns an empty dict, which + the framework already tolerates). + """ + super().init_metric() + self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() + + def update_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + losses: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: + """Update metric state. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + losses (dict, optional): a dict of loss. + """ + if "input_embedding" in predictions: + _, rel = recon_diagnostics( + predictions["input_embedding"], + predictions["quantized"], + ) + # MeanSquaredError aggregates (preds, target) itself; rel_loss has + # no torchmetrics equivalent so it stays a MeanMetric. + self._metric_modules["mse"].update( + predictions["quantized"], predictions["input_embedding"] + ) + self._metric_modules["rel_loss"].update(rel) + + self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) + + @torch.no_grad() + def on_train_end(self) -> None: + """Trigger one-shot FAISS fit after the train_eval loop ends. + + Overrides :meth:`BaseModel.on_train_end`. Called unconditionally + by ``tzrec.main.train_and_evaluate`` after the training loop exits. + + DDP behavior: + - rank0: receive each rank's reservoir sample via gather_object, + concat, run FAISS fit, then broadcast centroids to all ranks. + - other ranks: ship their reservoir sample via gather_object + (dst=0) and wait for the broadcast. + + No cross-rank empty-buffer handshake is needed: the dataset layer + enforces ``num_files >= world_size`` (``tzrec.datasets.dataset`` + raises otherwise), so in synchronized training every rank receives + at least one shard and reaches the gather with a non-empty sample. + """ + is_ddp = ( + dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 + ) + + local = self._reservoir_sample() + self._reset_reservoir() + + if is_ddp: + # DDP path: every rank ships its reservoir sample to rank 0 via + # gather_object. Each sample is bounded by _sample_cap, so the + # gathered set on rank0 is ~train_sample_size and FAISS does no + # further subsampling. + rank = dist.get_rank() + gathered: Optional[List[Optional[torch.Tensor]]] = ( + [None] * dist.get_world_size() if rank == 0 else None + ) + dist.gather_object(local, gathered, dst=0) + del local + if rank == 0: + assert gathered is not None + full = torch.cat([g for g in gathered if g is not None], dim=0) + del gathered + logger.info( + "[SidRqkmeans.on_train_end] rank0 fitting FAISS " + "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) + ) + self._quantizer.train_offline(full, verbose=True) + del full + # Broadcast centroids and set the init flag locally on every + # rank. ``_is_initialized`` is a bool buffer and NCCL's bool + # dtype support is inconsistent across versions, so we avoid + # a separate broadcast for it — all ranks enter this block in + # lockstep, so a local fill_() keeps state consistent. + for layer in self._quantizer.layers: + dist.broadcast(layer.centroids, src=0) + layer._is_initialized.fill_(True) + dist.barrier() + return + + # Single-process path. Guard an empty sample with a plain local check + # (no collective): on_train_end may be invoked without a training pass. + if local.shape[0] == 0: + logger.warning( + "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " + "fit. Did the train_eval loop run?" + ) + return + + logger.info( + "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." + % (local.shape[0], local.shape[1]) + ) + self._quantizer.train_offline(local, verbose=True) diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py new file mode 100644 index 000000000..8b224afac --- /dev/null +++ b/tzrec/models/sid_rqkmeans_test.py @@ -0,0 +1,306 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torchrec import KeyedTensor + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.models.sid_rqkmeans import SidRqkmeans +from tzrec.protos import model_pb2 +from tzrec.protos.models import sid_model_pb2 +from tzrec.utils import misc_util +from tzrec.utils.state_dict_util import init_parameters + +WORLD_SIZE = 2 + + +def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: + """Create a minimal Batch with dense embedding features.""" + dense_feature = KeyedTensor.from_tensor_list( + keys=["item_emb"], + tensors=[torch.randn(batch_size, input_dim, device=device)], + ) + return Batch( + dense_features={BASE_DATA_GROUP: dense_feature}, + sparse_features={}, + labels={}, + ) + + +def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmeans: + """Build a SidRqkmeans configured for offline FAISS fit. + + Module-level (not a method) so the spawned DDP workers below can build + the same model; callers move it to a device / init params as needed. + SID models read the item-embedding dense feature directly from the batch + and do not consume feature_groups, so none is set. + """ + from google.protobuf.struct_pb2 import Struct + + n_embed_list = codebook if codebook is not None else [16] * n_layers + faiss_kwargs = Struct() + faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + cfg = sid_model_pb2.SidRqkmeans( + input_dim=input_dim, + codebook=n_embed_list, + normalize_residuals=False, + faiss_kmeans_kwargs=faiss_kwargs, + embedding_feature_name="item_emb", + ) + return SidRqkmeans( + model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), + features=[], + labels=[], + ) + + +class SidRqkmeansOfflineTest(unittest.TestCase): + """Single-process tests for SidRqkmeans (FAISS-only).""" + + def _create_model(self, input_dim=32, n_layers=2, niter=5, codebook=None): + """Create a SidRqkmeans on CPU with params initialized.""" + model = _build_model(input_dim, n_layers, niter, codebook) + init_parameters(model, device=torch.device("cpu")) + return model + + def test_proto_parse(self) -> None: + """Verify faiss_kmeans_kwargs are parsed correctly.""" + model = self._create_model() + self.assertEqual(model._faiss_kwargs.get("niter"), 5) + self.assertEqual(model._faiss_kwargs.get("seed"), 1234) + self.assertFalse(model._faiss_kwargs.get("verbose")) + self.assertEqual(model._n_seen, 0) + self.assertIsNone(model._reservoir) + + def test_predict_collects_buffer(self) -> None: + """In train mode, predict reservoir-samples; never fits.""" + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim) + model.train() + + for _ in range(4): + batch = _make_batch(B, input_dim) + preds = model.predict(batch) + self.assertIn("codes", preds) + + # Reservoir holds all 4*B samples (well under the cap) and tracks + # the running count. + self.assertEqual(model._n_seen, 4 * B) + self.assertEqual(model._n_filled, 4 * B) + # FAISS not yet triggered: layers should be uninitialized + for layer in model._quantizer.layers: + self.assertFalse(layer.is_initialized) + + def test_reservoir_caps_memory(self) -> None: + """Reservoir bounds the buffer at _sample_cap regardless of corpus.""" + B, input_dim = 16, 8 + model = self._create_model(input_dim=input_dim) + model._sample_cap = 10 # force a tiny cap + model._reset_reservoir() + model.train() + for _ in range(20): # 320 rows >> cap + model.predict(_make_batch(B, input_dim)) + self.assertEqual(model._n_seen, 20 * B) + self.assertEqual(model._n_filled, 10) + self.assertEqual(model._reservoir.shape, (10, input_dim)) + + def test_on_train_end_runs_faiss(self) -> None: + """on_train_end triggers FAISS fit and clears buffer.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim) + model.train() + + # Accumulate enough samples (FAISS K-Means needs at least K points) + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + self.assertGreater(model._n_seen, 0) + + # Trigger one-shot FAISS fit + model.on_train_end() + + # Reservoir should be released after the fit + self.assertEqual(model._n_seen, 0) + self.assertIsNone(model._reservoir) + # All layers should be initialized + centroids non-zero + for layer in model._quantizer.layers: + self.assertTrue(bool(layer._is_initialized.item())) + self.assertGreater(layer.centroids.abs().sum().item(), 0.0) + + # After fit, predict on eval should produce valid codes + model.eval() + preds = model.predict(_make_batch(B, input_dim)) + codes = preds["codes"] + self.assertEqual(codes.shape, (B, 2)) + self.assertTrue((codes >= 0).all() and (codes < 16).all()) + + def test_non_uniform_codebook_end_to_end(self) -> None: + """Non-uniform codebook [8, 4, 16]: fit then emit per-layer codes.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + codebook = [8, 4, 16] + model = self._create_model(input_dim=input_dim, codebook=codebook) + # Reservoir cap derives from the LARGEST K (16), not the first (8). + self.assertEqual( + model._sample_cap, + 16 * int(model._faiss_kwargs.get("max_points_per_centroid", 256)), + ) + + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + model.on_train_end() + + for k, layer in zip(codebook, model._quantizer.layers): + self.assertTrue(bool(layer._is_initialized.item())) + self.assertEqual(layer.centroids.shape[0], k) + + model.eval() + codes = model.predict(_make_batch(B, input_dim))["codes"] + self.assertEqual(codes.shape, (B, 3)) + for i, k in enumerate(codebook): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + + def test_on_train_end_noop_on_empty_buffer(self) -> None: + """on_train_end on an empty buffer is a warned no-op.""" + model = self._create_model() + model.on_train_end() # should not raise + + def test_post_fit_checkpoint_round_trips(self) -> None: + """Fit → save state_dict → load into fresh instance → predict. + + After loading, ``predict`` must return real (non-zero) codes — + the centroids and the ``_is_initialized`` flag both need to come + through the state_dict. + """ + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + src = self._create_model(input_dim=input_dim) + src.train() + for _ in range(8): + src.predict(_make_batch(B, input_dim)) + src.on_train_end() + sd = src.state_dict() + + dst = self._create_model(input_dim=input_dim) + dst.load_state_dict(sd) + dst.eval() + codes = dst.predict(_make_batch(B, input_dim))["codes"] + self.assertGreater( + codes.abs().sum().item(), + 0, + "post-fit checkpoint resume produced all-zero codes", + ) + + def test_mid_fit_checkpoint_rejected_on_load(self) -> None: + """Tampered state (_is_initialized=True + zero centroids) raises.""" + model = self._create_model() + sd = model.state_dict() + # Simulate a checkpoint that captured the flag mid-fit (before + # load_centroids_ ran): True flag, zero centroids. + layer0_prefix = next( + k.rsplit("._is_initialized", 1)[0] + for k in sd + if k.endswith("._is_initialized") + ) + sd[f"{layer0_prefix}._is_initialized"] = torch.tensor(True) + + fresh = self._create_model() + with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): + fresh.load_state_dict(sd) + + +# -------------------------------------------------------------------------- +# Distributed (multi-process) test for the DDP on_train_end path: the +# cross-rank gather_object -> FAISS fit -> broadcast sequence the in-process +# tests above cannot reach. NCCL on GPU when >=2 devices, else gloo/CPU. +# -------------------------------------------------------------------------- +def _init_dist(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: + device = _init_dist(rank, world_size, port) + input_dim, n_layers, k = 16, 2, 16 + model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) + model.train() + + torch.manual_seed(100 + rank) + for _ in range(6): + model.predict(_make_batch(32, input_dim, device)) + assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" + + # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. + model.on_train_end() + + for layer in model._quantizer.layers: + assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" + assert layer.centroids.abs().sum().item() > 0.0, f"rank{rank}: zero centroids" + # Centroids were broadcast from rank0 -> must be bit-identical across ranks. + for layer in model._quantizer.layers: + cmin, cmax = layer.centroids.clone(), layer.centroids.clone() + dist.all_reduce(cmin, op=dist.ReduceOp.MIN) + dist.all_reduce(cmax, op=dist.ReduceOp.MAX) + assert torch.allclose(cmin, cmax), f"rank{rank}: centroids differ across ranks" + + model.eval() + codes = model.predict(_make_batch(8, input_dim, device))["codes"] + assert codes.shape == (8, n_layers), f"rank{rank}: bad codes shape {codes.shape}" + assert (codes >= 0).all() and (codes < k).all(), f"rank{rank}: codes out of range" + dist.destroy_process_group() + + +class SidRqkmeansDistTest(unittest.TestCase): + """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" + + def test_on_train_end_ddp(self) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=_on_train_end_worker, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py new file mode 100644 index 000000000..0b6fe4255 --- /dev/null +++ b/tzrec/modules/sid/kmeans.py @@ -0,0 +1,222 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""K-Means utilities for the SID-generation stack. + +This module is the single home for torch-native K-Means code used by +SID models: + +* :class:`KMeansLayer` — per-layer centroid container used by + :class:`ResidualKMeansQuantizer`. Centroids are injected + by the FAISS backend via ``load_centroids_``; the only forward path + is ``predict``. +* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by + :class:`ResidualVectorQuantizer` to warm-start the RQ-VAE codebook on the + first training batch (same FAISS backend as the offline RQ-KMeans fit). +""" + +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn + + +def recon_diagnostics( + x: torch.Tensor, + out: torch.Tensor, + epsilon: float = 1e-4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """MSE + relative-L1 reconstruction diagnostics. + + Shared by :meth:`SidRqkmeans.update_metric` (which wants tensors for + ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeansQuantizer.train_offline`'s + per-layer log line (which converts to Python floats via ``.item()``). + + Args: + x: ground-truth embedding, shape (B, D). + out: quantized reconstruction, shape (B, D). + epsilon: numerical stabilizer for the relative-L1 denominator. + + Returns: + mse: scalar ``((out - x) ** 2).mean()``. + rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. + """ + mse = ((out - x) ** 2).mean() + rel = ( + torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) + ).mean() + return mse, rel + + +@torch.no_grad() +def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Squared L2 distance between rows of ``x`` and ``y``. + + Args: + x (Tensor): data points, shape (N, D). + y (Tensor): centroids, shape (K, D). + + Returns: + Tensor: squared distances, shape (N, K). + + Called per-batch from :meth:`KMeansLayer.predict`, so ``N`` is the batch + size and the full (N, K) product is small. Kept branch-free (no + data-dependent chunking on ``N``) so the predict forward stays + FX-traceable: torchrec's inference pipeline symbolically traces the + model, and a ``if N <= chunk_size`` on the traced batch dim raises a + ``torch.fx`` TraceError. + """ + x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) + y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) + return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) + + +@torch.no_grad() +def faiss_residual_kmeans( + samples: torch.Tensor, + n_clusters_list: List[int], + faiss_kmeans_kwargs: Optional[Dict] = None, +) -> List[torch.Tensor]: + """Residual K-Means warm-start via FAISS, one pass per layer. + + Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned + centroid, and repeats on the residual for every layer. Used by + :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook + from the first training batch — the same FAISS backend the offline + RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. + + Args: + samples (Tensor): data points, shape (N, D). + n_clusters_list (List[int]): per-layer cluster counts. + faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` + (e.g. ``{'niter': 10, 'seed': 123}``). + + Returns: + List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. + + Raises: + ImportError: if ``faiss`` is not installed. + """ + try: + import faiss + except ImportError as e: + raise ImportError( + "faiss is required for RQ-VAE kmeans_init. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + kwargs = dict(faiss_kmeans_kwargs or {}) + device = samples.device + _, D = samples.shape + # Own a contiguous fp32 numpy copy we mutate in place to form residuals. + x = samples.detach().cpu().float().numpy().copy() + + res_centers: List[torch.Tensor] = [] + for n_clusters in n_clusters_list: + kmeans = faiss.Kmeans(D, n_clusters, **kwargs) + kmeans.train(x) + centroids = kmeans.centroids.copy() # (K, D) + res_centers.append(torch.from_numpy(centroids).to(device)) + _, idx = kmeans.index.search(x, 1) + x -= centroids[idx.ravel()] # residual, in place + return res_centers + + +class KMeansLayer(nn.Module): + """Single layer of a residual K-Means stack. + + Centroids are populated externally by ``load_centroids_`` (called per + layer by the FAISS backend in :class:`ResidualKMeansQuantizer`); ``predict`` + is the only forward path. PyTorch state-dict keys are scoped by + attribute path (``layers..centroids``), so renaming the class + does not break existing checkpoints. + + Args: + n_clusters (int): number of clusters (codebook size). + n_features (int): feature dimension. + """ + + def __init__( + self, + n_clusters: int, + n_features: int, + ) -> None: + super().__init__() + self.n_clusters = n_clusters + self.n_features = n_features + + self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) + # Flipped by ``load_centroids_`` after the FAISS fit. Persistent + # so a normal post-fit checkpoint round-trips; mid-fit poisoning + # (True flag + still-zero centroids) is caught in _load_from_state_dict. + self.register_buffer("_is_initialized", torch.tensor(False)) + + @property + def is_initialized(self) -> bool: + """Whether centroids have been injected via ``load_centroids_``.""" + return self._is_initialized.item() + + @torch.no_grad() + def load_centroids_(self, centroids: torch.Tensor) -> None: + """Inject offline-trained centroids. + + Args: + centroids (Tensor): externally trained centroids, + shape (n_clusters, n_features). + """ + assert centroids.shape == self.centroids.shape, ( + f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " + f"got {tuple(centroids.shape)}" + ) + self.centroids.copy_( + centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) + ) + self._is_initialized.fill_(True) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) -> None: + """Reject mid-fit-checkpoint state dicts (True flag + zero centroids).""" + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ) + if bool(self._is_initialized.item()) and self.centroids.abs().sum() == 0: + error_msgs.append( + f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " + "are all zero — checkpoint was likely taken mid-FAISS-fit. " + "Re-run on_train_end to produce a valid checkpoint." + ) + + @torch.no_grad() + def predict(self, batch: torch.Tensor) -> torch.Tensor: + """Assign points to nearest centroid. + + Args: + batch (Tensor): data points, shape (B, D). + + Returns: + Tensor: cluster indices, shape (B,). + """ + dists = _squared_euclidean_distance(batch, self.centroids) + return torch.argmin(dists, dim=-1) diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py new file mode 100644 index 000000000..8fed1f83a --- /dev/null +++ b/tzrec/modules/sid/kmeans_test.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid.kmeans import ( + KMeansLayer, + _squared_euclidean_distance, + faiss_residual_kmeans, + recon_diagnostics, +) + + +class KmeansHelpersTest(unittest.TestCase): + """Tests for the K-Means helper functions.""" + + def test_recon_diagnostics_zero_on_identity(self) -> None: + x = torch.randn(8, 4) + mse, rel = recon_diagnostics(x, x.clone()) + self.assertAlmostEqual(mse.item(), 0.0, places=6) + self.assertAlmostEqual(rel.item(), 0.0, places=6) + + def test_squared_euclidean_distance(self) -> None: + x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) + d = _squared_euclidean_distance(x, y) + self.assertEqual(d.shape, (2, 2)) + # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 + torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) + + def test_faiss_residual_kmeans_per_layer_centers(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + samples = torch.randn(512, 6) + centers = faiss_residual_kmeans( + samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} + ) + self.assertEqual(len(centers), 2) + self.assertEqual(centers[0].shape, (8, 6)) + self.assertEqual(centers[1].shape, (4, 6)) + self.assertTrue(torch.isfinite(centers[0]).all()) + self.assertEqual(centers[0].device, samples.device) + + +class KMeansLayerTest(unittest.TestCase): + """Tests for the single KMeansLayer.""" + + def test_uninitialized_by_default(self) -> None: + layer = KMeansLayer(n_clusters=4, n_features=3) + self.assertFalse(layer.is_initialized) + self.assertEqual(layer.centroids.abs().sum().item(), 0.0) + + def test_load_centroids_and_predict(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) + layer.load_centroids_(centroids) + self.assertTrue(layer.is_initialized) + + batch = torch.tensor([[0.1, 0.0], [9.0, 11.0]]) + codes = layer.predict(batch) + torch.testing.assert_close(codes, torch.tensor([0, 1])) + + def test_load_centroids_shape_mismatch_raises(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + with self.assertRaises(AssertionError): + layer.load_centroids_(torch.zeros(3, 2)) + + def test_mid_fit_checkpoint_rejected(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + sd = layer.state_dict() + # Simulate a mid-fit checkpoint: flag True but centroids still zero. + sd["_is_initialized"] = torch.tensor(True) + fresh = KMeansLayer(n_clusters=2, n_features=2) + with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): + fresh.load_state_dict(sd) + + def test_post_fit_checkpoint_round_trips(self) -> None: + layer = KMeansLayer(n_clusters=2, n_features=2) + layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) + fresh = KMeansLayer(n_clusters=2, n_features=2) + fresh.load_state_dict(layer.state_dict()) + self.assertTrue(fresh.is_initialized) + torch.testing.assert_close(fresh.centroids, layer.centroids) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py new file mode 100644 index 000000000..505a1b1dc --- /dev/null +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -0,0 +1,248 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-layer residual K-Means: ResidualKMeansQuantizer. + +Training is FAISS-only: the codebook is built once via ``train_offline`` +over the full embedding matrix; ``forward`` is read-only (predict + lookup). +""" + +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from tzrec.modules.sid.kmeans import KMeansLayer, recon_diagnostics +from tzrec.modules.sid.residual_quantizer import ResidualQuantizer +from tzrec.utils.logging_util import logger + + +class ResidualKMeansQuantizer(ResidualQuantizer): + """Multi-layer residual K-Means with offline FAISS training. + + Each layer quantizes the residual from the previous layer: + residual_0 = input + for each layer i: + (optionally) residual_i = L2_normalize(residual_i) + code_i = layer_i.predict(residual_i) + quantized_i = layer_i.centroids[code_i] + residual_{i+1} = residual_i - quantized_i + output = sum of all quantized_i + + Semantic ID = (code_0, code_1, ..., code_{n_layers-1}) + + Args: + embed_dim (int): feature dimension. + n_layers (int): number of residual quantization layers. + n_embed (int|List[int]): number of clusters per layer. Default: 256. + May differ per layer (non-uniform codebooks such as + ``[256, 512, 1024]`` are supported) — ``train_offline`` builds a + separate ``faiss.Kmeans`` per layer. + normalize_residuals (bool): whether to L2-normalize residuals + before each layer. Default: False. + faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to + ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, + 'gpu': True, 'verbose': True, 'spherical': False}). + """ + + def __init__( + self, + embed_dim: int, + n_layers: int, + n_embed: Union[int, List[int]] = 256, + normalize_residuals: bool = False, + faiss_kmeans_kwargs: Optional[Dict] = None, + ) -> None: + super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) + self.faiss_kmeans_kwargs = dict(faiss_kmeans_kwargs or {}) + + self.layers = nn.ModuleList( + [ + KMeansLayer( + n_clusters=self.n_embed_list[i], + n_features=embed_dim, + ) + for i in range(n_layers) + ] + ) + + def _quantize_layer( + self, + layer_idx: int, + residual: torch.Tensor, + temperature: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Nearest-centroid assignment for one layer. + + Uninitialized layers (before ``train_offline``) return zeros, so the + residual walk is a no-op and the model stays callable. ``temperature`` + is unused (no soft assignment). + + Args: + layer_idx (int): quantization layer index. + residual (Tensor): current residual, shape (B, D). + temperature (float): unused. + + Returns: + codes (Tensor): cluster indices, shape (B,). + quantized (Tensor): selected centroids, shape (B, D). + """ + layer = self.layers[layer_idx] + if not layer.is_initialized: + codes = torch.zeros( + residual.shape[0], dtype=torch.long, device=residual.device + ) + return codes, torch.zeros_like(residual) + codes = layer.predict(residual) + return codes, layer.centroids[codes] + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Assign codes per layer and sum the centroids. + + Codebook is read-only here; training happens in ``train_offline``. + Uninitialized layers contribute zeros (see :meth:`_quantize_layer`) so + the model is callable before the one-shot FAISS fit completes. + + Args: + input (Tensor): input embeddings, shape (B, D). + + Returns: + codes (Tensor): cluster indices per layer, shape (B, n_layers). + quantized (Tensor): sum of quantized embeddings, shape (B, D). + """ + cluster_ids, quantized_sum, _ = self._residual_pass(input) + return cluster_ids, quantized_sum + + @torch.no_grad() + def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: + """Get centroid weights for a specific layer. + + Args: + layer_idx (int): index of the quantization layer. + + Returns: + Tensor: centroids, shape (n_embed, embed_dim). + """ + return self.layers[layer_idx].centroids + + def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: + """Look up codebook vectors via the layer's centroid table.""" + return self.layers[layer_idx].centroids[code_idx] + + @torch.no_grad() + def train_offline( + self, + inputs: Union[torch.Tensor, "np.ndarray"], + verbose: bool = True, + ) -> None: + """Train the multi-layer codebook via offline FAISS K-Means. + + FAISS consumes torch tensors directly (via ``faiss.contrib. + torch_utils``) — no numpy round-trips. The residual matrix stays a + host (CPU) tensor; when a faiss-gpu build is present, ``gpu=`` + moves only FAISS's internal, subsampled working set to the GPU, so we + never hold (N, D) in VRAM. On a faiss-cpu build it runs on CPU + unchanged. Either way the code path is identical. + + Args: + inputs: full embedding matrix, shape (N, D), ``torch.Tensor`` or + ``np.ndarray``. Copied once to an owned CPU float32 tensor; + the caller's input is not mutated. + verbose (bool): whether to print per-layer reconstruction + loss. Default: True. + + Raises: + ImportError: if ``faiss`` is not installed. + """ + try: + import faiss + import faiss.contrib.torch_utils # noqa: F401 (torch tensor I/O) + except ImportError as e: + raise ImportError( + "faiss is required for ResidualKMeansQuantizer training. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + # Own a contiguous CPU float32 tensor we can update in place for + # residuals, without mutating the caller's input. + if isinstance(inputs, torch.Tensor): + assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + x = inputs.detach().to("cpu", torch.float32).contiguous().clone() + else: + assert inputs.ndim == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + x = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float32)).clone() + N = x.shape[0] + out = torch.zeros_like(x) + + # Use FAISS GPU compute when a GPU build is available (data stays on + # host; FAISS streams only its subsampled training set to the device). + # An explicit ``gpu`` in faiss_kmeans_kwargs always wins. + kwargs = dict(self.faiss_kmeans_kwargs) + if "gpu" not in kwargs: + kwargs["gpu"] = ( + torch.cuda.current_device() + if faiss.get_num_gpus() > 0 and torch.cuda.is_available() + else False + ) + + # Chunk size for index.search to limit peak memory. + # 500K × 512 × 4B ≈ 1 GB per chunk. + SEARCH_CHUNK = 500_000 + + for layer_idx in range(self.n_layers): + if self.normalize_residuals: + x = F.normalize(x, dim=-1) + + # Fresh Kmeans per layer so each layer can use its own K + # (non-uniform codebooks supported). Index construction is a cheap + # O(K*D) allocation next to train(), so this is effectively free. + kmeans = faiss.Kmeans( + self.embed_dim, self.n_embed_list[layer_idx], **kwargs + ) + kmeans.train(x) + centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32).cpu() + + for start in range(0, N, SEARCH_CHUNK): + end = min(start + SEARCH_CHUNK, N) + _, idx = kmeans.index.search(x[start:end], 1) + idx = torch.as_tensor(idx, device="cpu").reshape(-1).long() + q = centroids[idx] # (chunk, D) + out[start:end] += q + x[start:end] -= q # residual + del idx, q + + if verbose: + logger.info( + "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", + layer_idx, + self._calc_loss(out + x, out), # x_in = out + residual + ) + + self.layers[layer_idx].load_centroids_(centroids) + if verbose: + logger.info( + "[ResidualKMeansQuantizer][offline_faiss] layer %d finished", + layer_idx, + ) + + @staticmethod + def _calc_loss( + x: torch.Tensor, out: torch.Tensor, epsilon: float = 1e-4 + ) -> Dict[str, float]: + """Reconstruction loss diagnostics (MSE + relative L1).""" + loss, rel_loss = recon_diagnostics(x, out, epsilon=epsilon) + return {"loss": float(loss.item()), "rel_loss": float(rel_loss.item())} diff --git a/tzrec/modules/sid/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py index c94cc545d..d23ef1cf5 100644 --- a/tzrec/modules/sid/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -14,6 +14,9 @@ import torch from torch import nn +from tzrec.modules.sid.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) from tzrec.modules.sid.residual_quantizer import ( ResidualQuantizer, normalize_n_embed, @@ -142,5 +145,94 @@ def test_decode_codes_sum_and_dtype(self) -> None: self.assertEqual(recon16.dtype, torch.bfloat16) +class ResidualKMeansQuantizerTest(unittest.TestCase): + def test_is_subclass(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertIsInstance(rkq, ResidualQuantizer) + + def test_non_uniform_codebook_supported(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) + self.assertEqual(rkq.n_embed_list, [8, 4, 16]) + self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) + + def test_forward_returns_zeros_before_fit(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) + codes, quantized = rkq(torch.randn(5, 4)) + self.assertEqual(codes.shape, (5, 2)) + self.assertEqual(quantized.shape, (5, 4)) + + def test_forward_is_fx_traceable(self) -> None: + """Predict forward must FX-trace. + + torchrec's inference pipeline symbolically traces the model, so the + per-batch distance path must be free of data-dependent control flow. + """ + import torch.fx as fx + + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + for layer in rkq.layers: # populate centroids -> is_initialized=True + layer.load_centroids_(torch.randn(8, 4)) + traced = fx.symbolic_trace(rkq) + x = torch.randn(5, 4) + c_eager, q_eager = rkq(x) + c_traced, q_traced = traced(x) + torch.testing.assert_close(c_traced, c_eager) + torch.testing.assert_close(q_traced, q_eager) + + def test_train_offline_non_uniform(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + n_embed = [8, 4, 16] + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(512, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + # Each layer fit its own K centroids; codes stay in per-layer range. + codes, _ = rkq(torch.randn(7, 4)) + self.assertEqual(codes.shape, (7, 3)) + for i, k in enumerate(n_embed): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + + def test_train_offline_then_decode(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + + codes, _ = rkq(torch.randn(5, 4)) + self.assertTrue((codes >= 0).all() and (codes < 8).all()) + recon = rkq.decode_codes(codes) # inherited from the base + self.assertEqual(recon.shape, (5, 4)) + + def test_forward_get_codes_consistent(self) -> None: + """Forward ids and get_codes both route through the shared walk.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + x = torch.randn(9, 4) + fwd_ids, fwd_quant = rkq(x) + torch.testing.assert_close(rkq.get_codes(x), fwd_ids) + # forward's residual-sum equals the centroid-sum reconstruction. + torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto index bef2062ea..58b719a7a 100644 --- a/tzrec/protos/model.proto +++ b/tzrec/protos/model.proto @@ -5,6 +5,7 @@ import "tzrec/protos/models/rank_model.proto"; import "tzrec/protos/models/multi_task_rank.proto"; import "tzrec/protos/models/match_model.proto"; import "tzrec/protos/models/general_rank_model.proto"; +import "tzrec/protos/models/sid_model.proto"; import "tzrec/protos/loss.proto"; import "tzrec/protos/metric.proto"; import "tzrec/protos/seq_encoder.proto"; @@ -76,6 +77,10 @@ message ModelConfig { TDM tdm = 400; RocketLaunching rocket_launching = 500; + + // SID generation models + // (600 is reserved for SidRqvae, arriving in the follow-up PR) + SidRqkmeans sid_rqkmeans = 601; } optional uint32 num_class = 2 [default = 1]; diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto new file mode 100644 index 000000000..065013614 --- /dev/null +++ b/tzrec/protos/models/sid_model.proto @@ -0,0 +1,31 @@ +syntax = "proto2"; +package tzrec.protos; + +import "google/protobuf/struct.proto"; + +message SidRqkmeans { + // Input embedding dimension (K-Means runs directly on raw embeddings, + // no encoder). + optional uint32 input_dim = 1 [default = 512]; + // Per-layer cluster counts, e.g. [256, 256, 256]. + // List length is the number of residual quantization layers. Entries + // may differ per layer (non-uniform codebooks such as [256, 512, 1024] + // are supported — the FAISS backend fits a separate ``faiss.Kmeans`` + // per layer). + repeated uint32 codebook = 3; + // L2-normalize residuals before each layer. + optional bool normalize_residuals = 4 [default = true]; + // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a + // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, + // spherical: false, seed: 1234}. + optional google.protobuf.Struct faiss_kmeans_kwargs = 5; + // Target number of embeddings to reservoir-sample for the FAISS fit + // (global, across all ranks). Bounds host memory regardless of corpus + // size. 0 (the default) auto-derives it as K * max_points_per_centroid + // — exactly what FAISS subsamples to internally (default 256), so no + // training points are wasted. + optional uint32 train_sample_size = 6 [default = 0]; + + // Name of the item embedding feature inside the input Batch. + optional string embedding_feature_name = 40 [default = "item_emb"]; +} From c7f3a091fa3d02e4d3ed5824f6312759612bc30a Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 02:59:22 +0000 Subject: [PATCH 036/129] [review] SID: drop forced tail-checkpoint after on_train_end Remove the `last_ckpt_step == i_step -> -1` override (and its stale comment) in the train loop's end-of-loop hook. The normal checkpoint cadence already persists the post-hook state. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index 8824e8373..9efb8d8e4 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -502,12 +502,8 @@ def _train_and_evaluate( # One-shot end-of-loop hook (default no-op). Some models do real work # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings - # collected during training. Since that mutates model state, force the - # tail-save below to fire so the post-hook state is persisted even when - # the last in-loop checkpoint coincided with the final step. + # collected during training. _model.on_train_end() - if last_ckpt_step == i_step: - last_ckpt_step = -1 _log_train( i_step, From 61ec842c89f91b75eb45f87e2af20cce09e4bb69 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 05:48:00 +0000 Subject: [PATCH 037/129] [review] SID: address code-review findings on PR #539 - on_train_end() now returns is_ckpt_after_train; the tail save fires on `last_ckpt_step != i_step or is_ckpt_after_train`, so the fitted FAISS codebook is always persisted even when the last periodic checkpoint landed on the final step (main.py, model.py, sid_rqkmeans.py). (#1) - DDP on_train_end: wrap the rank0 FAISS fit in try/except and broadcast a fit-status flag so a rank0-only failure (or an empty reservoir) makes all ranks raise together instead of deadlocking on the centroid broadcast; correct the empty-reservoir docstring. (#2, #3) - KMeansLayer: cache is_initialized as a plain Python bool to drop a per-layer per-batch GPU->CPU .item() sync on the eval/predict path, kept in lockstep with the _is_initialized buffer. (#6) - _reservoir_add: copy only the kept rows to host instead of the whole batch every training step (keep float64 for n_seen exactness). (#7) - train_offline: per-layer fit-loss log now reports cumulative reconstruction of the original input (correct under normalize_residuals); align the module normalize_residuals default to True to match the proto. (#8, #10) - Drop dead faiss_residual_kmeans (RQ-VAE-only, lands in PR3) and its test; tidy _coerce_proto_numbers into a comprehension. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 8 +- tzrec/models/model.py | 10 ++- tzrec/models/sid_rqkmeans.py | 90 +++++++++++++------ tzrec/models/sid_rqkmeans_test.py | 10 ++- tzrec/modules/sid/kmeans.py | 70 +++------------ tzrec/modules/sid/kmeans_test.py | 17 ---- .../modules/sid/residual_kmeans_quantizer.py | 13 ++- 7 files changed, 106 insertions(+), 112 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index 9efb8d8e4..b71df9fe1 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -502,8 +502,10 @@ def _train_and_evaluate( # One-shot end-of-loop hook (default no-op). Some models do real work # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings - # collected during training. - _model.on_train_end() + # collected during training. When the hook mutated state that must be + # persisted, it returns True so the tail save below fires even if the + # last in-loop checkpoint already landed on the final step. + is_ckpt_after_train = _model.on_train_end() _log_train( i_step, @@ -518,7 +520,7 @@ def _train_and_evaluate( summary_writer.close() if train_config.is_profiling: prof.stop() - if last_ckpt_step != i_step: + if last_ckpt_step != i_step or is_ckpt_after_train: ckpt_manager.save(i_step, model, optimizer, dataloader_state) if eval_dataloader is not None: _evaluate( diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 10fa8aae5..09ffa1f58 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -150,14 +150,20 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: metric_results[metric_name] = metric.compute() return metric_results - def on_train_end(self) -> None: + def on_train_end(self) -> bool: """Hook fired once after the train_eval loop exits. Default: no-op. Override in models that need one-shot end-of-loop work — e.g. :class:`SidRqkmeans` uses this hook to fit the FAISS codebook from the embedding sample it collected during training. + + Returns: + is_ckpt_after_train (bool): whether the hook mutated model state + that must be persisted, so the train loop should force a final + checkpoint even when one was already saved at the last step. + Default ``False`` (no-op hooks change nothing). """ - pass + return False def sparse_parameters( self, diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index b9c3c8800..00aefa12b 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -44,13 +44,10 @@ def _coerce_proto_numbers(d: Dict) -> Dict: Python ``int``. This helper converts any float that is an exact integer to ``int`` for downstream consumption. """ - out: Dict = {} - for k, v in d.items(): - if isinstance(v, float) and v.is_integer(): - out[k] = int(v) - else: - out[k] = v - return out + return { + k: int(v) if isinstance(v, float) and v.is_integer() else v + for k, v in d.items() + } class SidRqkmeans(BaseSidModel): @@ -132,15 +129,17 @@ def _reservoir_add(self, x: torch.Tensor) -> None: Args: x (Tensor): a batch of embeddings, shape (B, D); copied to host. """ - x = x.detach().to("cpu", dtype=torch.float32) + x = x.detach() cap = self._sample_cap if self._reservoir is None: self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) - # Phase 1: fill empty slots first. + # Phase 1: fill empty slots first. Copy only the rows we keep to host. if self._n_filled < cap: take = min(x.shape[0], cap - self._n_filled) - self._reservoir[self._n_filled : self._n_filled + take] = x[:take] + self._reservoir[self._n_filled : self._n_filled + take] = x[:take].to( + "cpu", dtype=torch.float32 + ) self._n_filled += take self._n_seen += take x = x[take:] @@ -149,7 +148,12 @@ def _reservoir_add(self, x: torch.Tensor) -> None: # Phase 2: replacement. Row j (0-indexed in x) is the # (n_seen + j)-th item seen; it enters the reservoir with prob - # cap / (n_seen + j + 1), displacing a uniformly-random slot. + # cap / (n_seen + j + 1), displacing a uniformly-random slot. The + # accept decision needs only counts (not embedding values), so we + # compute it on small host index tensors and copy ONLY the accepted + # rows to host — in steady state (reservoir full, n_seen >> cap) + # almost none are accepted, so the whole-batch GPU->CPU copy is + # avoided. float64 keeps (n_seen + j + 1) exact past 2**24. r = x.shape[0] pos = self._n_seen + torch.arange(r) accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) @@ -158,7 +162,7 @@ def _reservoir_add(self, x: torch.Tensor) -> None: slots = torch.randint(0, cap, (idx.numel(),)) # Intra-batch slot collisions resolve last-write-wins; the bias is # O(B/cap) per step and negligible for codebook fitting. - self._reservoir[slots] = x[idx] + self._reservoir[slots] = x[idx.to(x.device)].to("cpu", dtype=torch.float32) self._n_seen += r def _reservoir_sample(self) -> torch.Tensor: @@ -271,7 +275,7 @@ def update_metric( self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) @torch.no_grad() - def on_train_end(self) -> None: + def on_train_end(self) -> bool: """Trigger one-shot FAISS fit after the train_eval loop ends. Overrides :meth:`BaseModel.on_train_end`. Called unconditionally @@ -283,10 +287,19 @@ def on_train_end(self) -> None: - other ranks: ship their reservoir sample via gather_object (dst=0) and wait for the broadcast. - No cross-rank empty-buffer handshake is needed: the dataset layer - enforces ``num_files >= world_size`` (``tzrec.datasets.dataset`` - raises otherwise), so in synchronized training every rank receives - at least one shard and reaches the gather with a non-empty sample. + Empty-reservoir handling: for any real-scale dataset every rank gets + a non-empty reservoir — the default ParquetDataset (``rebalance=True``) + splits rows across ``num_workers * world_size`` workers, so a rank only + ends up empty for a pathologically tiny corpus (``total_rows`` smaller + than that worker count). That degenerate case does not hang: rank0's + FAISS fit raises on too-few points and the fit-status broadcast below + makes every rank raise a coordinated ``RuntimeError`` instead. + + Returns: + is_ckpt_after_train (bool): ``True`` once the codebook has been + fitted here (the centroid buffers changed and must be persisted, + so the train loop forces a final checkpoint); ``False`` when the + fit was skipped (empty reservoir — nothing to persist). """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 @@ -306,16 +319,39 @@ def on_train_end(self) -> None: ) dist.gather_object(local, gathered, dst=0) del local + fit_ok = True if rank == 0: assert gathered is not None - full = torch.cat([g for g in gathered if g is not None], dim=0) - del gathered - logger.info( - "[SidRqkmeans.on_train_end] rank0 fitting FAISS " - "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) + try: + full = torch.cat([g for g in gathered if g is not None], dim=0) + del gathered + logger.info( + "[SidRqkmeans.on_train_end] rank0 fitting FAISS " + "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) + ) + self._quantizer.train_offline(full, verbose=True) + del full + except Exception as e: # noqa: BLE001 + # Swallow on rank0 only long enough to tell the peers — if + # we let it propagate here, ranks 1..N-1 would block forever + # on the centroid broadcast below with no sender. + fit_ok = False + logger.error( + "[SidRqkmeans.on_train_end] rank0 FAISS fit failed: %s", e + ) + # Sync rank0's status to every rank (int flag, not bool — see the + # NCCL note below) so a rank0-only failure makes all ranks raise + # together instead of deadlocking on the centroid broadcast. + status = torch.tensor( + [1 if fit_ok else 0], + device=self._quantizer.layers[0].centroids.device, + ) + dist.broadcast(status, src=0) + if int(status.item()) == 0: + raise RuntimeError( + "[SidRqkmeans.on_train_end] FAISS fit failed on rank0; " + "see rank0 logs for the underlying error." ) - self._quantizer.train_offline(full, verbose=True) - del full # Broadcast centroids and set the init flag locally on every # rank. ``_is_initialized`` is a bool buffer and NCCL's bool # dtype support is inconsistent across versions, so we avoid @@ -324,8 +360,9 @@ def on_train_end(self) -> None: for layer in self._quantizer.layers: dist.broadcast(layer.centroids, src=0) layer._is_initialized.fill_(True) + layer._initialized = True dist.barrier() - return + return True # Single-process path. Guard an empty sample with a plain local check # (no collective): on_train_end may be invoked without a training pass. @@ -334,10 +371,11 @@ def on_train_end(self) -> None: "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " "fit. Did the train_eval loop run?" ) - return + return False logger.info( "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." % (local.shape[0], local.shape[1]) ) self._quantizer.train_offline(local, verbose=True) + return True diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 8b224afac..30e204116 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -133,8 +133,8 @@ def test_on_train_end_runs_faiss(self) -> None: model.predict(_make_batch(B, input_dim)) self.assertGreater(model._n_seen, 0) - # Trigger one-shot FAISS fit - model.on_train_end() + # Trigger one-shot FAISS fit; a real fit must request a tail checkpoint + self.assertTrue(model.on_train_end()) # Reservoir should be released after the fit self.assertEqual(model._n_seen, 0) @@ -185,7 +185,8 @@ def test_non_uniform_codebook_end_to_end(self) -> None: def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() - model.on_train_end() # should not raise + # No fit happened, so no tail checkpoint is requested. + self.assertFalse(model.on_train_end()) # should not raise def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. @@ -266,7 +267,8 @@ def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. - model.on_train_end() + # Every rank fitted/received the codebook, so each requests a tail ckpt. + assert model.on_train_end(), f"rank{rank}: on_train_end should request ckpt" for layer in model._quantizer.layers: assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 0b6fe4255..ecc554aa5 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -18,12 +18,9 @@ :class:`ResidualKMeansQuantizer`. Centroids are injected by the FAISS backend via ``load_centroids_``; the only forward path is ``predict``. -* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by - :class:`ResidualVectorQuantizer` to warm-start the RQ-VAE codebook on the - first training batch (same FAISS backend as the offline RQ-KMeans fit). """ -from typing import Dict, List, Optional, Tuple +from typing import Tuple import torch from torch import nn @@ -79,57 +76,6 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) -@torch.no_grad() -def faiss_residual_kmeans( - samples: torch.Tensor, - n_clusters_list: List[int], - faiss_kmeans_kwargs: Optional[Dict] = None, -) -> List[torch.Tensor]: - """Residual K-Means warm-start via FAISS, one pass per layer. - - Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned - centroid, and repeats on the residual for every layer. Used by - :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook - from the first training batch — the same FAISS backend the offline - RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. - - Args: - samples (Tensor): data points, shape (N, D). - n_clusters_list (List[int]): per-layer cluster counts. - faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` - (e.g. ``{'niter': 10, 'seed': 123}``). - - Returns: - List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. - - Raises: - ImportError: if ``faiss`` is not installed. - """ - try: - import faiss - except ImportError as e: - raise ImportError( - "faiss is required for RQ-VAE kmeans_init. Install via " - "`pip install faiss-cpu` or `pip install faiss-gpu`." - ) from e - - kwargs = dict(faiss_kmeans_kwargs or {}) - device = samples.device - _, D = samples.shape - # Own a contiguous fp32 numpy copy we mutate in place to form residuals. - x = samples.detach().cpu().float().numpy().copy() - - res_centers: List[torch.Tensor] = [] - for n_clusters in n_clusters_list: - kmeans = faiss.Kmeans(D, n_clusters, **kwargs) - kmeans.train(x) - centroids = kmeans.centroids.copy() # (K, D) - res_centers.append(torch.from_numpy(centroids).to(device)) - _, idx = kmeans.index.search(x, 1) - x -= centroids[idx.ravel()] # residual, in place - return res_centers - - class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. @@ -158,11 +104,17 @@ def __init__( # so a normal post-fit checkpoint round-trips; mid-fit poisoning # (True flag + still-zero centroids) is caught in _load_from_state_dict. self.register_buffer("_is_initialized", torch.tensor(False)) + # Plain-Python mirror of ``_is_initialized``, read on the per-batch + # forward path (``_quantize_layer``) so the hot path never pays a + # ``.item()`` GPU->CPU sync. Kept in lockstep with the buffer wherever + # the buffer changes: ``load_centroids_``, ``_load_from_state_dict``, + # and the DDP broadcast in ``SidRqkmeans.on_train_end``. + self._initialized: bool = False @property def is_initialized(self) -> bool: """Whether centroids have been injected via ``load_centroids_``.""" - return self._is_initialized.item() + return self._initialized @torch.no_grad() def load_centroids_(self, centroids: torch.Tensor) -> None: @@ -180,6 +132,7 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) ) self._is_initialized.fill_(True) + self._initialized = True def _load_from_state_dict( self, @@ -201,7 +154,10 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ) - if bool(self._is_initialized.item()) and self.centroids.abs().sum() == 0: + # Mirror the restored buffer into the cached Python flag (one sync at + # load time, off the hot path). + self._initialized = bool(self._is_initialized.item()) + if self._initialized and self.centroids.abs().sum() == 0: error_msgs.append( f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " "are all zero — checkpoint was likely taken mid-FAISS-fit. " diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index 8fed1f83a..cb86a39d8 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -16,7 +16,6 @@ from tzrec.modules.sid.kmeans import ( KMeansLayer, _squared_euclidean_distance, - faiss_residual_kmeans, recon_diagnostics, ) @@ -38,22 +37,6 @@ def test_squared_euclidean_distance(self) -> None: # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) - def test_faiss_residual_kmeans_per_layer_centers(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - samples = torch.randn(512, 6) - centers = faiss_residual_kmeans( - samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} - ) - self.assertEqual(len(centers), 2) - self.assertEqual(centers[0].shape, (8, 6)) - self.assertEqual(centers[1].shape, (4, 6)) - self.assertTrue(torch.isfinite(centers[0]).all()) - self.assertEqual(centers[0].device, samples.device) - class KMeansLayerTest(unittest.TestCase): """Tests for the single KMeansLayer.""" diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 505a1b1dc..72a539654 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -49,7 +49,8 @@ class ResidualKMeansQuantizer(ResidualQuantizer): ``[256, 512, 1024]`` are supported) — ``train_offline`` builds a separate ``faiss.Kmeans`` per layer. normalize_residuals (bool): whether to L2-normalize residuals - before each layer. Default: False. + before each layer. Default: True, matching the ``SidRqkmeans`` + proto default so direct instantiation agrees with the config path. faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, 'gpu': True, 'verbose': True, 'spherical': False}). @@ -60,7 +61,7 @@ def __init__( embed_dim: int, n_layers: int, n_embed: Union[int, List[int]] = 256, - normalize_residuals: bool = False, + normalize_residuals: bool = True, faiss_kmeans_kwargs: Optional[Dict] = None, ) -> None: super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) @@ -187,6 +188,12 @@ def train_offline( x = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float32)).clone() N = x.shape[0] out = torch.zeros_like(x) + # Keep the original input only when we log: the per-layer diagnostic + # is the cumulative reconstruction error of the *original* input by + # the centroid sum so far (the same quantity update_metric reports). + # ``out + x`` would equal it only when normalize_residuals is off; with + # normalization the residual is rescaled each layer, so track x0. + x0 = x.clone() if verbose else None # Use FAISS GPU compute when a GPU build is available (data stays on # host; FAISS streams only its subsampled training set to the device). @@ -229,7 +236,7 @@ def train_offline( logger.info( "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, - self._calc_loss(out + x, out), # x_in = out + residual + self._calc_loss(x0, out), # cumulative recon of original input ) self.layers[layer_idx].load_centroids_(centroids) From 753f3fe94cd94ce05d819b8425fd76baafcf6959 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 06:04:30 +0000 Subject: [PATCH 038/129] [review] SID: default normalize_residuals to False Flip the default for RQ-KMeans residual normalization to False, in both the SidRqkmeans proto field and the ResidualKMeansQuantizer constructor (kept consistent to avoid the proto/module mismatch). This matches OpenOneRec's residual k-means, which fits raw residuals with no per-layer L2 normalization. Configs that set normalize_residuals explicitly are unaffected. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/residual_kmeans_quantizer.py | 7 ++++--- tzrec/protos/models/sid_model.proto | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 72a539654..a3bfb1dae 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -49,8 +49,9 @@ class ResidualKMeansQuantizer(ResidualQuantizer): ``[256, 512, 1024]`` are supported) — ``train_offline`` builds a separate ``faiss.Kmeans`` per layer. normalize_residuals (bool): whether to L2-normalize residuals - before each layer. Default: True, matching the ``SidRqkmeans`` - proto default so direct instantiation agrees with the config path. + before each layer. Default: False, matching the ``SidRqkmeans`` + proto default (and OpenOneRec's residual k-means, which fits raw + residuals with no per-layer normalization). faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, 'gpu': True, 'verbose': True, 'spherical': False}). @@ -61,7 +62,7 @@ def __init__( embed_dim: int, n_layers: int, n_embed: Union[int, List[int]] = 256, - normalize_residuals: bool = True, + normalize_residuals: bool = False, faiss_kmeans_kwargs: Optional[Dict] = None, ) -> None: super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 065013614..6c3d1b297 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -14,7 +14,7 @@ message SidRqkmeans { // per layer). repeated uint32 codebook = 3; // L2-normalize residuals before each layer. - optional bool normalize_residuals = 4 [default = true]; + optional bool normalize_residuals = 4 [default = false]; // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, // spherical: false, seed: 1234}. From 52c745224431ac5e32969a68db4bdc64028b6467 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 07:47:17 +0000 Subject: [PATCH 039/129] [review] SID: encapsulation, comment, and import cleanups - KMeansLayer: add mark_initialized_() so the buffer + cached-bool init flag is owned by the layer; the DDP broadcast in SidRqkmeans uses it instead of poking the private fields. - SidRqkmeans: extract the reservoir-cap setup into _init_reservoir(). - residual_kmeans_quantizer: import faiss at module level (it's a pinned requirement) instead of a lazy in-function import; narrow train_offline(inputs) to torch.Tensor (all callers pass tensors) and drop the dead numpy branch. - Tighten the verbose comments/docstrings across the SID files. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 8 +- tzrec/models/model.py | 12 +- tzrec/models/sid_rqkmeans.py | 156 +++++++----------- tzrec/modules/sid/kmeans.py | 49 +++--- .../modules/sid/residual_kmeans_quantizer.py | 72 +++----- 5 files changed, 114 insertions(+), 183 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index b71df9fe1..8b4b5357b 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -500,11 +500,9 @@ def _train_and_evaluate( if lr.by_epoch: lr.step() - # One-shot end-of-loop hook (default no-op). Some models do real work - # here — e.g. SidRqkmeans fits its FAISS codebook from the embeddings - # collected during training. When the hook mutated state that must be - # persisted, it returns True so the tail save below fires even if the - # last in-loop checkpoint already landed on the final step. + # One-shot end-of-loop hook (default no-op; e.g. SidRqkmeans fits its FAISS + # codebook here). Returns True if it mutated persistable state, forcing the + # tail save below even when the last in-loop checkpoint hit the final step. is_ckpt_after_train = _model.on_train_end() _log_train( diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 09ffa1f58..c6b2b952c 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -153,15 +153,13 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: def on_train_end(self) -> bool: """Hook fired once after the train_eval loop exits. - Default: no-op. Override in models that need one-shot end-of-loop - work — e.g. :class:`SidRqkmeans` uses this hook to fit the FAISS - codebook from the embedding sample it collected during training. + Default no-op; override for one-shot end-of-loop work (e.g. + :class:`SidRqkmeans` fits its FAISS codebook here). Returns: - is_ckpt_after_train (bool): whether the hook mutated model state - that must be persisted, so the train loop should force a final - checkpoint even when one was already saved at the last step. - Default ``False`` (no-op hooks change nothing). + is_ckpt_after_train (bool): whether the hook mutated state that must + be persisted, so the loop forces a final checkpoint even if one was + already saved at the last step. Default ``False``. """ return False diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 00aefa12b..3859a3ef0 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -37,12 +37,10 @@ def _coerce_proto_numbers(d: Dict) -> Dict: - """Coerce float-typed integers back to int. + """Coerce whole-valued floats back to int. - ``google.protobuf.Struct.number_value`` is always float, but most - ``faiss.Kmeans`` kwargs (``niter``, ``seed``, ``nredo``, ...) require - Python ``int``. This helper converts any float that is an exact - integer to ``int`` for downstream consumption. + ``Struct.number_value`` is always float, but faiss.Kmeans kwargs + (``niter``, ``seed``, ...) need ``int``. """ return { k: int(v) if isinstance(v, float) and v.is_integer() else v @@ -76,9 +74,7 @@ def __init__( cfg = self._model_config # SidRqkmeans proto message - # config_to_kwargs returns Struct numbers as floats (it is - # MessageToDict under the hood), so _coerce_proto_numbers restores - # the ints faiss.Kmeans expects (niter, seed, nredo, ...). + # config_to_kwargs yields Struct numbers as floats; coerce back to int. self._faiss_kwargs = ( _coerce_proto_numbers(config_util.config_to_kwargs(cfg.faiss_kmeans_kwargs)) if cfg.HasField("faiss_kmeans_kwargs") @@ -93,41 +89,41 @@ def __init__( faiss_kmeans_kwargs=self._faiss_kwargs, ) - # Per-rank reservoir cap. FAISS K-Means only ever consumes - # K * max_points_per_centroid points (it subsamples internally), so - # buffering the full corpus is wasted memory. We reservoir-sample to - # that target instead, split across ranks so the gathered set on - # rank0 is ~train_sample_size and FAISS does no further subsampling. - # Use the LARGEST per-layer K so non-uniform codebooks (e.g. - # [256, 512, 1024]) still feed their biggest layer enough points. + self._init_reservoir() + + # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. + self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) + + def _init_reservoir(self) -> None: + """Set up the bounded host reservoir for the end-of-loop FAISS fit. + + Per-rank cap: FAISS subsamples to K*max_points_per_centroid internally, + so reservoir-sample to that target (split across ranks) rather than + buffer the whole corpus. Use the largest per-layer K so non-uniform + codebooks still feed their biggest layer enough points. + """ k = max(self._n_embed_list) max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) - global_target = ( - cfg.train_sample_size if cfg.train_sample_size > 0 else k * max_ppc - ) + target = self._model_config.train_sample_size + global_target = target if target > 0 else k * max_ppc world_size = dist.get_world_size() if dist.is_initialized() else 1 self._sample_cap = max(1, -(-global_target // world_size)) # ceil div - # Bounded host-resident reservoir (allocated lazily on first batch, - # once the embedding dim/device is known). ``_n_filled`` slots hold - # data; ``_n_seen`` is the running count for the sampling probability. + # Allocated lazily on the first batch. _n_filled = used slots; + # _n_seen = running count for the accept prob. self._reservoir: Optional[torch.Tensor] = None self._n_filled = 0 self._n_seen = 0 - # KMeans has no learnable parameters (centroids use register_buffer). - # Add dummy param to keep optimizer/DDP happy. - self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) - @torch.no_grad() def _reservoir_add(self, x: torch.Tensor) -> None: - """Add a batch to the bounded reservoir (Vitter's Algorithm R). + """Stream a batch into the reservoir (Vitter Algorithm R). - Keeps a uniform random ``self._sample_cap`` subset of every embedding - seen so far in O(cap) host memory, in a single streaming pass. + Keeps a uniform ``_sample_cap`` sample of all embeddings seen, in + O(cap) host memory. Args: - x (Tensor): a batch of embeddings, shape (B, D); copied to host. + x (Tensor): batch of embeddings, shape (B, D). """ x = x.detach() cap = self._sample_cap @@ -146,22 +142,17 @@ def _reservoir_add(self, x: torch.Tensor) -> None: if x.shape[0] == 0: return - # Phase 2: replacement. Row j (0-indexed in x) is the - # (n_seen + j)-th item seen; it enters the reservoir with prob - # cap / (n_seen + j + 1), displacing a uniformly-random slot. The - # accept decision needs only counts (not embedding values), so we - # compute it on small host index tensors and copy ONLY the accepted - # rows to host — in steady state (reservoir full, n_seen >> cap) - # almost none are accepted, so the whole-batch GPU->CPU copy is - # avoided. float64 keeps (n_seen + j + 1) exact past 2**24. + # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random + # slot. The accept decision needs only counts, so compute it on host and + # copy ONLY accepted rows (in steady state, almost none) — avoiding the + # whole-batch GPU->CPU copy. float64 keeps n_seen+j+1 exact past 2**24. r = x.shape[0] pos = self._n_seen + torch.arange(r) accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) idx = accept.nonzero(as_tuple=True)[0] if idx.numel() > 0: slots = torch.randint(0, cap, (idx.numel(),)) - # Intra-batch slot collisions resolve last-write-wins; the bias is - # O(B/cap) per step and negligible for codebook fitting. + # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. self._reservoir[slots] = x[idx.to(x.device)].to("cpu", dtype=torch.float32) self._n_seen += r @@ -191,10 +182,8 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """ embedding = self._extract_feature(batch) - # Training: reservoir-sample into a bounded host buffer for the - # end-of-loop FAISS fit, and return dummy codes — the codebook does - # not exist yet. The reservoir caps memory at _sample_cap rows - # regardless of corpus size (FAISS only consumes a subset anyway). + # Training: just reservoir-sample for the end-of-loop FAISS fit and + # return dummy codes — the codebook does not exist yet. if self.is_train: self._reservoir_add(embedding) B = embedding.shape[0] @@ -221,9 +210,8 @@ def loss( ) -> Dict[str, torch.Tensor]: """Compute loss of the model. - Returns zero loss to keep TrainWrapper backward happy. - _dummy_param * 0.0 ensures a compute graph exists so DDP - does not complain about unused parameters. + Zero loss via ``_dummy_param * 0`` — gives TrainWrapper/DDP a compute + graph despite there being no real trainable params. Args: predictions (dict): a dict of predicted result. @@ -235,14 +223,11 @@ def loss( return {"dummy_loss": self._dummy_param.sum() * 0.0} def init_metric(self) -> None: - """Initialize metric modules (shared eval metrics + rel_loss). - - Only eval metrics are registered. During training ``predict`` - returns dummy zero codes (the codebook does not exist yet), so - any train-time metric would be either NaN or trivially constant; - the inherited no-op ``update_train_metric`` keeps the train path - empty (``compute_train_metric`` then returns an empty dict, which - the framework already tolerates). + """Register eval metrics (shared ``mse`` + ``rel_loss``). + + Train-time metrics are intentionally absent: ``predict`` returns dummy + codes pre-fit, so the inherited no-op ``update_train_metric`` keeps the + train path empty. """ super().init_metric() self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() @@ -265,8 +250,8 @@ def update_metric( predictions["input_embedding"], predictions["quantized"], ) - # MeanSquaredError aggregates (preds, target) itself; rel_loss has - # no torchmetrics equivalent so it stays a MeanMetric. + # mse aggregates (preds, target) itself; rel_loss has no + # torchmetrics equivalent, so it stays a MeanMetric. self._metric_modules["mse"].update( predictions["quantized"], predictions["input_embedding"] ) @@ -276,30 +261,20 @@ def update_metric( @torch.no_grad() def on_train_end(self) -> bool: - """Trigger one-shot FAISS fit after the train_eval loop ends. - - Overrides :meth:`BaseModel.on_train_end`. Called unconditionally - by ``tzrec.main.train_and_evaluate`` after the training loop exits. + """Fit the FAISS codebook once, after the train_eval loop exits. - DDP behavior: - - rank0: receive each rank's reservoir sample via gather_object, - concat, run FAISS fit, then broadcast centroids to all ranks. - - other ranks: ship their reservoir sample via gather_object - (dst=0) and wait for the broadcast. + Overrides :meth:`BaseModel.on_train_end` (called unconditionally by + ``tzrec.main``). DDP: every rank gather_objects its reservoir to rank0, + which fits and broadcasts the centroids back. - Empty-reservoir handling: for any real-scale dataset every rank gets - a non-empty reservoir — the default ParquetDataset (``rebalance=True``) - splits rows across ``num_workers * world_size`` workers, so a rank only - ends up empty for a pathologically tiny corpus (``total_rows`` smaller - than that worker count). That degenerate case does not hang: rank0's - FAISS fit raises on too-few points and the fit-status broadcast below - makes every rank raise a coordinated ``RuntimeError`` instead. + An empty reservoir only happens for a pathologically tiny corpus + (rebalance splits rows across ``num_workers * world_size``); it then + fails fast via the fit-status broadcast rather than hanging. Returns: - is_ckpt_after_train (bool): ``True`` once the codebook has been - fitted here (the centroid buffers changed and must be persisted, - so the train loop forces a final checkpoint); ``False`` when the - fit was skipped (empty reservoir — nothing to persist). + is_ckpt_after_train (bool): ``True`` if the codebook was fitted + (centroids changed → force a final checkpoint), ``False`` if the + fit was skipped (empty reservoir). """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 @@ -309,10 +284,7 @@ def on_train_end(self) -> bool: self._reset_reservoir() if is_ddp: - # DDP path: every rank ships its reservoir sample to rank 0 via - # gather_object. Each sample is bounded by _sample_cap, so the - # gathered set on rank0 is ~train_sample_size and FAISS does no - # further subsampling. + # Each rank ships its (capped) reservoir to rank0, which fits. rank = dist.get_rank() gathered: Optional[List[Optional[torch.Tensor]]] = ( [None] * dist.get_world_size() if rank == 0 else None @@ -332,16 +304,14 @@ def on_train_end(self) -> bool: self._quantizer.train_offline(full, verbose=True) del full except Exception as e: # noqa: BLE001 - # Swallow on rank0 only long enough to tell the peers — if - # we let it propagate here, ranks 1..N-1 would block forever - # on the centroid broadcast below with no sender. + # Don't raise yet — peers would hang on the broadcast below. + # Signal failure via the status flag so all ranks raise. fit_ok = False logger.error( "[SidRqkmeans.on_train_end] rank0 FAISS fit failed: %s", e ) - # Sync rank0's status to every rank (int flag, not bool — see the - # NCCL note below) so a rank0-only failure makes all ranks raise - # together instead of deadlocking on the centroid broadcast. + # Broadcast rank0's status (int, not bool — see NCCL note below) so + # a rank0-only failure makes all ranks raise instead of deadlocking. status = torch.tensor( [1 if fit_ok else 0], device=self._quantizer.layers[0].centroids.device, @@ -352,20 +322,16 @@ def on_train_end(self) -> bool: "[SidRqkmeans.on_train_end] FAISS fit failed on rank0; " "see rank0 logs for the underlying error." ) - # Broadcast centroids and set the init flag locally on every - # rank. ``_is_initialized`` is a bool buffer and NCCL's bool - # dtype support is inconsistent across versions, so we avoid - # a separate broadcast for it — all ranks enter this block in - # lockstep, so a local fill_() keeps state consistent. + # Broadcast centroids; set the init flag locally (avoids + # broadcasting a bool buffer — NCCL bool support is inconsistent). + # All ranks are in lockstep, so a local mark_initialized_() agrees. for layer in self._quantizer.layers: dist.broadcast(layer.centroids, src=0) - layer._is_initialized.fill_(True) - layer._initialized = True + layer.mark_initialized_() dist.barrier() return True - # Single-process path. Guard an empty sample with a plain local check - # (no collective): on_train_end may be invoked without a training pass. + # Single-process: guard an empty reservoir with a plain local check. if local.shape[0] == 0: logger.warning( "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index ecc554aa5..d6e34acd9 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -33,9 +33,8 @@ def recon_diagnostics( ) -> Tuple[torch.Tensor, torch.Tensor]: """MSE + relative-L1 reconstruction diagnostics. - Shared by :meth:`SidRqkmeans.update_metric` (which wants tensors for - ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeansQuantizer.train_offline`'s - per-layer log line (which converts to Python floats via ``.item()``). + Shared by :meth:`SidRqkmeans.update_metric` and + :meth:`ResidualKMeansQuantizer.train_offline`'s per-layer log. Args: x: ground-truth embedding, shape (B, D). @@ -64,12 +63,8 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso Returns: Tensor: squared distances, shape (N, K). - Called per-batch from :meth:`KMeansLayer.predict`, so ``N`` is the batch - size and the full (N, K) product is small. Kept branch-free (no - data-dependent chunking on ``N``) so the predict forward stays - FX-traceable: torchrec's inference pipeline symbolically traces the - model, and a ``if N <= chunk_size`` on the traced batch dim raises a - ``torch.fx`` TraceError. + Kept branch-free (no data-dependent control flow on ``N``) so the + per-batch predict forward stays FX-traceable for torchrec inference. """ x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) @@ -79,11 +74,9 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. - Centroids are populated externally by ``load_centroids_`` (called per - layer by the FAISS backend in :class:`ResidualKMeansQuantizer`); ``predict`` - is the only forward path. PyTorch state-dict keys are scoped by - attribute path (``layers..centroids``), so renaming the class - does not break existing checkpoints. + Centroids are populated externally by ``load_centroids_`` (the FAISS + backend in :class:`ResidualKMeansQuantizer`); ``predict`` is the only + forward path. Args: n_clusters (int): number of clusters (codebook size). @@ -100,15 +93,12 @@ def __init__( self.n_features = n_features self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) - # Flipped by ``load_centroids_`` after the FAISS fit. Persistent - # so a normal post-fit checkpoint round-trips; mid-fit poisoning - # (True flag + still-zero centroids) is caught in _load_from_state_dict. + # Persistent so a post-fit checkpoint round-trips; a mid-fit poison + # (True flag + zero centroids) is caught in _load_from_state_dict. self.register_buffer("_is_initialized", torch.tensor(False)) - # Plain-Python mirror of ``_is_initialized``, read on the per-batch - # forward path (``_quantize_layer``) so the hot path never pays a - # ``.item()`` GPU->CPU sync. Kept in lockstep with the buffer wherever - # the buffer changes: ``load_centroids_``, ``_load_from_state_dict``, - # and the DDP broadcast in ``SidRqkmeans.on_train_end``. + # Plain-Python mirror of the buffer, read on the per-batch forward + # path to avoid a .item() GPU->CPU sync. Synced only via + # mark_initialized_ and _load_from_state_dict. self._initialized: bool = False @property @@ -116,6 +106,15 @@ def is_initialized(self) -> bool: """Whether centroids have been injected via ``load_centroids_``.""" return self._initialized + def mark_initialized_(self) -> None: + """Flag centroids populated, syncing buffer + cached mirror. + + For callers that fill ``centroids`` in place (e.g. the DDP broadcast + in :meth:`SidRqkmeans.on_train_end`) rather than via ``load_centroids_``. + """ + self._is_initialized.fill_(True) + self._initialized = True + @torch.no_grad() def load_centroids_(self, centroids: torch.Tensor) -> None: """Inject offline-trained centroids. @@ -131,8 +130,7 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: self.centroids.copy_( centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) ) - self._is_initialized.fill_(True) - self._initialized = True + self.mark_initialized_() def _load_from_state_dict( self, @@ -154,8 +152,7 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ) - # Mirror the restored buffer into the cached Python flag (one sync at - # load time, off the hot path). + # Mirror the restored buffer into the cached flag (one load-time sync). self._initialized = bool(self._is_initialized.item()) if self._initialized and self.centroids.abs().sum() == 0: error_msgs.append( diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index a3bfb1dae..e28891f9c 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -17,7 +17,8 @@ from typing import Dict, List, Optional, Tuple, Union -import numpy as np +import faiss +import faiss.contrib.torch_utils # noqa: F401 (registers torch tensor I/O) import torch from torch import nn from torch.nn import functional as F @@ -144,61 +145,34 @@ def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: @torch.no_grad() def train_offline( self, - inputs: Union[torch.Tensor, "np.ndarray"], + inputs: torch.Tensor, verbose: bool = True, ) -> None: """Train the multi-layer codebook via offline FAISS K-Means. - FAISS consumes torch tensors directly (via ``faiss.contrib. - torch_utils``) — no numpy round-trips. The residual matrix stays a - host (CPU) tensor; when a faiss-gpu build is present, ``gpu=`` - moves only FAISS's internal, subsampled working set to the GPU, so we - never hold (N, D) in VRAM. On a faiss-cpu build it runs on CPU - unchanged. Either way the code path is identical. + The residual matrix stays a host (CPU) tensor; with a faiss-gpu build, + ``gpu=`` moves only FAISS's subsampled working set to the GPU, so + we never hold (N, D) in VRAM. faiss-cpu runs the same path on CPU. Args: - inputs: full embedding matrix, shape (N, D), ``torch.Tensor`` or - ``np.ndarray``. Copied once to an owned CPU float32 tensor; - the caller's input is not mutated. - verbose (bool): whether to print per-layer reconstruction - loss. Default: True. - - Raises: - ImportError: if ``faiss`` is not installed. + inputs (Tensor): embedding matrix (N, D). Copied once to an owned + CPU float32 tensor; not mutated. + verbose (bool): print per-layer reconstruction loss. Default: True. """ - try: - import faiss - import faiss.contrib.torch_utils # noqa: F401 (torch tensor I/O) - except ImportError as e: - raise ImportError( - "faiss is required for ResidualKMeansQuantizer training. Install via " - "`pip install faiss-cpu` or `pip install faiss-gpu`." - ) from e - - # Own a contiguous CPU float32 tensor we can update in place for - # residuals, without mutating the caller's input. - if isinstance(inputs, torch.Tensor): - assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( - f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" - ) - x = inputs.detach().to("cpu", torch.float32).contiguous().clone() - else: - assert inputs.ndim == 2 and inputs.shape[1] == self.embed_dim, ( - f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" - ) - x = torch.from_numpy(np.ascontiguousarray(inputs, dtype=np.float32)).clone() + # Own a contiguous CPU float32 copy to update in place as the residual. + assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) + x = inputs.detach().to("cpu", torch.float32).contiguous().clone() N = x.shape[0] out = torch.zeros_like(x) - # Keep the original input only when we log: the per-layer diagnostic - # is the cumulative reconstruction error of the *original* input by - # the centroid sum so far (the same quantity update_metric reports). - # ``out + x`` would equal it only when normalize_residuals is off; with - # normalization the residual is rescaled each layer, so track x0. + # Original input, kept only for the log: the per-layer diagnostic is the + # cumulative recon error of x0 by the centroid sum (what update_metric + # reports). ``out + x`` would equal it only without normalization. x0 = x.clone() if verbose else None - # Use FAISS GPU compute when a GPU build is available (data stays on - # host; FAISS streams only its subsampled training set to the device). - # An explicit ``gpu`` in faiss_kmeans_kwargs always wins. + # Use FAISS GPU compute when a faiss-gpu build is present; an explicit + # ``gpu`` in faiss_kmeans_kwargs always wins. kwargs = dict(self.faiss_kmeans_kwargs) if "gpu" not in kwargs: kwargs["gpu"] = ( @@ -207,17 +181,15 @@ def train_offline( else False ) - # Chunk size for index.search to limit peak memory. - # 500K × 512 × 4B ≈ 1 GB per chunk. + # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). SEARCH_CHUNK = 500_000 for layer_idx in range(self.n_layers): if self.normalize_residuals: x = F.normalize(x, dim=-1) - # Fresh Kmeans per layer so each layer can use its own K - # (non-uniform codebooks supported). Index construction is a cheap - # O(K*D) allocation next to train(), so this is effectively free. + # Fresh Kmeans per layer so each can use its own K (non-uniform + # codebooks). kmeans = faiss.Kmeans( self.embed_dim, self.n_embed_list[layer_idx], **kwargs ) From fbd973ffd1e391b1df4579dfe9987a9e90760704 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 07:56:52 +0000 Subject: [PATCH 040/129] [review] SID: move FAISS fit-sample sizing into the quantizer Add ResidualKMeansQuantizer.default_fit_sample_size() (max(K) * max_points_per_centroid) so the FAISS default lives in the FAISS-owning class; SidRqkmeans._init_reservoir asks the quantizer instead of reading faiss_kwargs and hardcoding 256. Behavior-identical. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 13 ++++++------- tzrec/modules/sid/residual_kmeans_quantizer.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3859a3ef0..9f29b4eac 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -97,15 +97,14 @@ def __init__( def _init_reservoir(self) -> None: """Set up the bounded host reservoir for the end-of-loop FAISS fit. - Per-rank cap: FAISS subsamples to K*max_points_per_centroid internally, - so reservoir-sample to that target (split across ranks) rather than - buffer the whole corpus. Use the largest per-layer K so non-uniform - codebooks still feed their biggest layer enough points. + Per-rank cap: target the points the FAISS fit will subsample to + (``ResidualKMeansQuantizer.default_fit_sample_size``), split across + ranks, rather than buffer the whole corpus. """ - k = max(self._n_embed_list) - max_ppc = int(self._faiss_kwargs.get("max_points_per_centroid", 256)) target = self._model_config.train_sample_size - global_target = target if target > 0 else k * max_ppc + global_target = ( + target if target > 0 else self._quantizer.default_fit_sample_size() + ) world_size = dist.get_world_size() if dist.is_initialized() else 1 self._sample_cap = max(1, -(-global_target // world_size)) # ceil div diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index e28891f9c..8e5960e99 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -142,6 +142,16 @@ def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: """Look up codebook vectors via the layer's centroid table.""" return self.layers[layer_idx].centroids[code_idx] + def default_fit_sample_size(self) -> int: + """Points the FAISS fit subsamples to: max(K) * max_points_per_centroid. + + ``faiss.Kmeans`` caps each layer's training set at + ``K * max_points_per_centroid`` (default 256), so fitting on more is + wasted. Callers use this to size their training-sample reservoir. + """ + max_ppc = int(self.faiss_kmeans_kwargs.get("max_points_per_centroid", 256)) + return max(self.n_embed_list) * max_ppc + @torch.no_grad() def train_offline( self, From 893a62794b17fc5aaf22bd0b9710b6a26ef8bdf1 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 8 Jun 2026 08:22:38 +0000 Subject: [PATCH 041/129] [review] SID: log rank0 FAISS-fit failure with traceback Use logger.exception() in on_train_end's rank0 except so the underlying error's stack trace is captured (peers raise a coordinated RuntimeError pointing at the rank0 log); drop the now-unused `as e`. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 9f29b4eac..3401f7568 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -302,12 +302,14 @@ def on_train_end(self) -> bool: ) self._quantizer.train_offline(full, verbose=True) del full - except Exception as e: # noqa: BLE001 + except Exception: # noqa: BLE001 # Don't raise yet — peers would hang on the broadcast below. # Signal failure via the status flag so all ranks raise. + # logger.exception keeps the traceback so the rank0-only + # failure is diagnosable from the log. fit_ok = False - logger.error( - "[SidRqkmeans.on_train_end] rank0 FAISS fit failed: %s", e + logger.exception( + "[SidRqkmeans.on_train_end] rank0 FAISS fit failed" ) # Broadcast rank0's status (int, not bool — see NCCL note below) so # a rank0-only failure makes all ranks raise instead of deadlocking. From 3734fc21dfee8d7ebc98b9b4e6d957ed84df8e11 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 02:36:08 +0000 Subject: [PATCH 042/129] [review] SID: clarify the reservoir ceil-div comment Comment-only; pushed to re-trigger CI. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3401f7568..25bd4f19a 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -106,7 +106,8 @@ def _init_reservoir(self) -> None: target if target > 0 else self._quantizer.default_fit_sample_size() ) world_size = dist.get_world_size() if dist.is_initialized() else 1 - self._sample_cap = max(1, -(-global_target // world_size)) # ceil div + # ceil div: round up so the per-rank caps together cover global_target. + self._sample_cap = max(1, -(-global_target // world_size)) # Allocated lazily on the first batch. _n_filled = used slots; # _n_seen = running count for the accept prob. From 795c676569188719079e72183fe33083f06d1547 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 03:40:15 +0000 Subject: [PATCH 043/129] [review] SID: fix FAISS gpu kwarg + close test gaps from PR review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Should-fix: - train_offline: faiss reads `gpu` as a GPU *count*, not a device index, so `gpu=current_device()` was 0 (single-GPU / rank0) -> falsy -> silent CPU fallback. Pass `gpu=1` so the fit actually runs on the (rank0) GPU. Test gaps: - reservoir Phase-2 replacement correctness (identifiable rows: intact, in-range, replacement actually occurs) — beyond the count/shape checks. - normalize_residuals=True end-to-end through train_offline (the F.normalize site the other tests never reached). - eval vs inference predict contract (quantized/input_embedding vs codes-only) and the init_metric/update_metric/compute_metric path. - checkpoint round-trip now asserts codes match the source model exactly (assert_close), not merely non-zero. Minor docs: - on_train_end Returns: clarify only the single-process path returns False; DDP raises on an empty gather. - train_offline docstring: the post-fit index.search streams all N in chunks. - proto train_sample_size comment: K -> max(K) for non-uniform codebooks. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 7 +- tzrec/models/sid_rqkmeans_test.py | 169 ++++++++++++++++-- .../modules/sid/residual_kmeans_quantizer.py | 17 +- tzrec/protos/models/sid_model.proto | 7 +- 4 files changed, 173 insertions(+), 27 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 25bd4f19a..384c049a9 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -273,8 +273,11 @@ def on_train_end(self) -> bool: Returns: is_ckpt_after_train (bool): ``True`` if the codebook was fitted - (centroids changed → force a final checkpoint), ``False`` if the - fit was skipped (empty reservoir). + (centroids changed → force a final checkpoint). Only the + single-process path can return ``False`` (empty reservoir, fit + skipped); the DDP path either returns ``True`` or raises (an empty + gather makes rank0's fit fail, which the status broadcast turns + into a coordinated ``RuntimeError``). """ is_ddp = ( dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 30e204116..2cfad30a1 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -27,12 +27,9 @@ WORLD_SIZE = 2 -def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: - """Create a minimal Batch with dense embedding features.""" - dense_feature = KeyedTensor.from_tensor_list( - keys=["item_emb"], - tensors=[torch.randn(batch_size, input_dim, device=device)], - ) +def _batch_from_rows(rows: torch.Tensor) -> Batch: + """Wrap explicit ``item_emb`` rows in a minimal Batch.""" + dense_feature = KeyedTensor.from_tensor_list(keys=["item_emb"], tensors=[rows]) return Batch( dense_features={BASE_DATA_GROUP: dense_feature}, sparse_features={}, @@ -40,7 +37,14 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: ) -def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmeans: +def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: + """Create a minimal Batch with random dense embedding features.""" + return _batch_from_rows(torch.randn(batch_size, input_dim, device=device)) + + +def _build_model( + input_dim=32, n_layers=2, niter=5, codebook=None, normalize_residuals=False +) -> SidRqkmeans: """Build a SidRqkmeans configured for offline FAISS fit. Module-level (not a method) so the spawned DDP workers below can build @@ -56,7 +60,7 @@ def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmean cfg = sid_model_pb2.SidRqkmeans( input_dim=input_dim, codebook=n_embed_list, - normalize_residuals=False, + normalize_residuals=normalize_residuals, faiss_kmeans_kwargs=faiss_kwargs, embedding_feature_name="item_emb", ) @@ -70,9 +74,16 @@ def _build_model(input_dim=32, n_layers=2, niter=5, codebook=None) -> SidRqkmean class SidRqkmeansOfflineTest(unittest.TestCase): """Single-process tests for SidRqkmeans (FAISS-only).""" - def _create_model(self, input_dim=32, n_layers=2, niter=5, codebook=None): + def _create_model( + self, + input_dim=32, + n_layers=2, + niter=5, + codebook=None, + normalize_residuals=False, + ): """Create a SidRqkmeans on CPU with params initialized.""" - model = _build_model(input_dim, n_layers, niter, codebook) + model = _build_model(input_dim, n_layers, niter, codebook, normalize_residuals) init_parameters(model, device=torch.device("cpu")) return model @@ -117,6 +128,51 @@ def test_reservoir_caps_memory(self) -> None: self.assertEqual(model._n_filled, 10) self.assertEqual(model._reservoir.shape, (10, input_dim)) + def test_reservoir_phase2_replacement(self) -> None: + """Phase-2 replacement keeps a valid reservoir of real, in-range rows. + + Feeds identifiable rows (each row's value == its global stream index), + then asserts every reservoir slot still holds an intact fed row, all + indices are in range, and replacement past the initial fill actually + happened — exercising the accept-prob / slot-write logic that the + count/shape-only ``test_reservoir_caps_memory`` cannot. + """ + torch.manual_seed(0) + input_dim, cap, B, n_batches = 4, 8, 4, 50 + model = self._create_model(input_dim=input_dim) + model._sample_cap = cap + model._reset_reservoir() + model.train() + + gidx = 0 + for _ in range(n_batches): + rows = ( + torch.arange(gidx, gidx + B, dtype=torch.float32) + .unsqueeze(1) + .expand(B, input_dim) + .contiguous() + ) + gidx += B + model.predict(_batch_from_rows(rows)) + + total = B * n_batches + self.assertEqual(model._n_seen, total) + self.assertEqual(model._n_filled, cap) + + res = model._reservoir + idx = res[:, 0].round().long() + # Each stored row is an intact fed row (all columns equal its index), + # never zeros/garbage. + self.assertTrue( + torch.equal(res, idx.unsqueeze(1).float().expand_as(res)), + "reservoir holds corrupted (non-fed) rows", + ) + # All indices are valid stream positions. + self.assertTrue((idx >= 0).all() and (idx < total).all()) + # Phase-2 replacement happened: at least one slot holds a row added + # after the reservoir filled (index >= cap). + self.assertTrue((idx >= cap).any(), "no Phase-2 replacement occurred") + def test_on_train_end_runs_faiss(self) -> None: """on_train_end triggers FAISS fit and clears buffer.""" try: @@ -182,6 +238,83 @@ def test_non_uniform_codebook_end_to_end(self) -> None: for i, k in enumerate(codebook): self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + def test_normalize_residuals_end_to_end(self) -> None: + """train_offline with normalize_residuals=True fits + predicts. + + Exercises the ``F.normalize`` site inside ``train_offline`` (a second + normalize independent of ``_residual_pass``), which the other tests — + all built with normalize_residuals=False — never reach. + """ + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim, normalize_residuals=True) + self.assertTrue(model._quantizer.normalize_residuals) + + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + self.assertTrue(model.on_train_end()) + + for layer in model._quantizer.layers: + self.assertTrue(layer.is_initialized) + + model.eval() + codes = model.predict(_make_batch(B, input_dim))["codes"] + self.assertEqual(codes.shape, (B, 2)) + self.assertTrue((codes >= 0).all() and (codes < 16).all()) + + def test_eval_and_inference_predict_contract(self) -> None: + """Eval exposes quantized/input_embedding; inference is codes-only.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim) + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + model.on_train_end() + + # Eval mode: reconstruction outputs are present for update_metric. + model.eval() + eval_preds = model.predict(_make_batch(B, input_dim)) + self.assertIn("quantized", eval_preds) + self.assertIn("input_embedding", eval_preds) + + # Inference (serving) mode: codes-only contract. + model.set_is_inference(True) + inf_preds = model.predict(_make_batch(B, input_dim)) + self.assertEqual(set(inf_preds.keys()), {"codes"}) + + def test_eval_metric_path(self) -> None: + """init_metric/update_metric report finite mse + rel_loss in eval.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + B, input_dim = 64, 32 + model = self._create_model(input_dim=input_dim) + model.train() + for _ in range(8): + model.predict(_make_batch(B, input_dim)) + model.on_train_end() + + model.init_metric() + model.eval() + preds = model.predict(_make_batch(B, input_dim)) + model.update_metric(preds, _make_batch(B, input_dim)) + metrics = model.compute_metric() + for key in ("mse", "rel_loss", "unique_sid_ratio"): + self.assertIn(key, metrics) + self.assertTrue(torch.isfinite(torch.as_tensor(metrics[key])).all()) + def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() @@ -191,9 +324,9 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. - After loading, ``predict`` must return real (non-zero) codes — - the centroids and the ``_is_initialized`` flag both need to come - through the state_dict. + The reloaded model must produce the *same* codes as the source on the + same batch — verifying the centroids round-trip exactly, not merely + that they came through as non-zero. """ try: import faiss # noqa: F401 @@ -210,13 +343,19 @@ def test_post_fit_checkpoint_round_trips(self) -> None: dst = self._create_model(input_dim=input_dim) dst.load_state_dict(sd) + + # Same batch through both → identical codes (exact round-trip). + batch = _make_batch(B, input_dim) + src.eval() dst.eval() - codes = dst.predict(_make_batch(B, input_dim))["codes"] + src_codes = src.predict(batch)["codes"] + dst_codes = dst.predict(batch)["codes"] self.assertGreater( - codes.abs().sum().item(), + dst_codes.abs().sum().item(), 0, "post-fit checkpoint resume produced all-zero codes", ) + torch.testing.assert_close(dst_codes, src_codes) def test_mid_fit_checkpoint_rejected_on_load(self) -> None: """Tampered state (_is_initialized=True + zero centroids) raises.""" diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 8e5960e99..83e095192 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -160,9 +160,12 @@ def train_offline( ) -> None: """Train the multi-layer codebook via offline FAISS K-Means. - The residual matrix stays a host (CPU) tensor; with a faiss-gpu build, - ``gpu=`` moves only FAISS's subsampled working set to the GPU, so - we never hold (N, D) in VRAM. faiss-cpu runs the same path on CPU. + The residual matrix stays a host (CPU) tensor. With a faiss-gpu build, + ``faiss.Kmeans`` runs the K-Means training (over its internally + subsampled set) on the GPU; the post-fit ``index.search`` assignment + still streams all N rows through in ``SEARCH_CHUNK``-sized chunks, so we + never hold the full (N, D) on the device. faiss-cpu runs the same path + on CPU. Args: inputs (Tensor): embedding matrix (N, D). Copied once to an owned @@ -182,13 +185,13 @@ def train_offline( x0 = x.clone() if verbose else None # Use FAISS GPU compute when a faiss-gpu build is present; an explicit - # ``gpu`` in faiss_kmeans_kwargs always wins. + # ``gpu`` in faiss_kmeans_kwargs always wins. NB faiss reads ``gpu`` as a + # GPU *count* (1 = one GPU = the current/rank0 device), not a device + # index — passing an index of 0 is falsy and silently falls back to CPU. kwargs = dict(self.faiss_kmeans_kwargs) if "gpu" not in kwargs: kwargs["gpu"] = ( - torch.cuda.current_device() - if faiss.get_num_gpus() > 0 and torch.cuda.is_available() - else False + 1 if (faiss.get_num_gpus() > 0 and torch.cuda.is_available()) else False ) # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 6c3d1b297..fdd41a22c 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -21,9 +21,10 @@ message SidRqkmeans { optional google.protobuf.Struct faiss_kmeans_kwargs = 5; // Target number of embeddings to reservoir-sample for the FAISS fit // (global, across all ranks). Bounds host memory regardless of corpus - // size. 0 (the default) auto-derives it as K * max_points_per_centroid - // — exactly what FAISS subsamples to internally (default 256), so no - // training points are wasted. + // size. 0 (the default) auto-derives it as max(K) * max_points_per_centroid + // (the largest per-layer codebook, for non-uniform codebooks) — exactly + // what FAISS subsamples to internally (default 256), so no training points + // are wasted. optional uint32 train_sample_size = 6 [default = 0]; // Name of the item embedding feature inside the input Batch. From 2bb5abc117692678bea701c6e94cd26fa771cfe8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 06:02:49 +0000 Subject: [PATCH 044/129] [review] SID: default FAISS fit to CPU + DDP fit-failure test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The earlier gpu=1 "fix" was itself wrong and broke the GPU unittest_ci (cpu_ci/h20 passed): faiss reads `gpu` as a COUNT and 1 == True collapses to all-GPUs, so the rank0-only fit sharded over every rank's device and the GPU faiss path (newly activated — it was a silent CPU fallback before) failed on the tiny test data. faiss's count kwarg cannot pin to a single device, so default the fit to CPU (a bounded one-shot; set gpu in faiss_kmeans_kwargs to opt in explicitly). Also: - _init_reservoir docstring: note the cap targets train_sample_size when set, else default_fit_sample_size(). - Add test_on_train_end_ddp_rank0_failure: forces rank0's fit to raise and asserts every rank raises the coordinated RuntimeError, with join(timeout) so a reintroduced deadlock fails CI instead of hanging. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 5 +- tzrec/models/sid_rqkmeans_test.py | 64 +++++++++++++++++++ .../modules/sid/residual_kmeans_quantizer.py | 15 ++--- 3 files changed, 74 insertions(+), 10 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 384c049a9..66f87cdd2 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -97,9 +97,10 @@ def __init__( def _init_reservoir(self) -> None: """Set up the bounded host reservoir for the end-of-loop FAISS fit. - Per-rank cap: target the points the FAISS fit will subsample to + Per-rank cap: target ``train_sample_size`` when set (>0), else the + points the FAISS fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``), split across - ranks, rather than buffer the whole corpus. + ranks — rather than buffer the whole corpus. """ target = self._model_config.train_sample_size global_target = ( diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 2cfad30a1..e54e44f53 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -426,6 +426,41 @@ def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: dist.destroy_process_group() +def _on_train_end_fail_worker(rank: int, world_size: int, port: int) -> None: + """Worker that forces rank0's FAISS fit to fail. + + Every rank must then raise the coordinated ``RuntimeError`` (driven by the + fit-status broadcast) instead of deadlocking on the centroid broadcast. A + worker returns 0 only if it caught that expected error. + """ + device = _init_dist(rank, world_size, port) + input_dim, n_layers, k = 16, 2, 16 + model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) + model.train() + for _ in range(6): + model.predict(_make_batch(32, input_dim, device)) + + # Force the rank0-only fit to raise (no faiss needed: only rank0 fits, and + # we replace its fit). The status flag must turn this into an all-ranks + # raise, not a hang. + if rank == 0: + + def _boom(*args, **kwargs): + raise RuntimeError("forced rank0 fit failure") + + model._quantizer.train_offline = _boom + + try: + model.on_train_end() + except RuntimeError: + dist.destroy_process_group() + return # expected: coordinated failure reached this rank + dist.destroy_process_group() + raise AssertionError( + f"rank{rank}: on_train_end did not raise on a rank0 fit failure" + ) + + class SidRqkmeansDistTest(unittest.TestCase): """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" @@ -442,6 +477,35 @@ def test_on_train_end_ddp(self) -> None: if p.exitcode != 0: raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + def test_on_train_end_ddp_rank0_failure(self) -> None: + """A rank0-only fit failure raises on every rank — never deadlocks. + + Guards the status-flag-before-centroid-broadcast ordering: a regression + that reordered/dropped it would hang here. ``join(timeout=...)`` turns a + reintroduced deadlock into a CI failure instead of a hung job. + """ + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process( + target=_on_train_end_fail_worker, args=(rank, WORLD_SIZE, port) + ) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join(timeout=120) + if p.is_alive(): + p.terminate() + raise RuntimeError( + f"worker-{i} deadlocked on a rank0 fit failure (timed out)." + ) + if p.exitcode != 0: + raise RuntimeError( + f"worker-{i} did not raise the coordinated error " + f"(exitcode={p.exitcode})." + ) + if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 83e095192..d3ebf82c2 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -184,15 +184,14 @@ def train_offline( # reports). ``out + x`` would equal it only without normalization. x0 = x.clone() if verbose else None - # Use FAISS GPU compute when a faiss-gpu build is present; an explicit - # ``gpu`` in faiss_kmeans_kwargs always wins. NB faiss reads ``gpu`` as a - # GPU *count* (1 = one GPU = the current/rank0 device), not a device - # index — passing an index of 0 is falsy and silently falls back to CPU. + # Default to a CPU fit. faiss reads ``gpu`` as a GPU *count*, not a + # device index (and ``1 == True`` collapses to all GPUs), so it cannot + # pin this rank0-only fit to a single device without sharding faiss + # memory onto the other ranks' GPUs. The fit is a bounded one-shot over + # the reservoir subsample, so CPU is cheap; set ``gpu`` explicitly in + # faiss_kmeans_kwargs (e.g. ``True`` for all GPUs) to opt into GPU. kwargs = dict(self.faiss_kmeans_kwargs) - if "gpu" not in kwargs: - kwargs["gpu"] = ( - 1 if (faiss.get_num_gpus() > 0 and torch.cuda.is_available()) else False - ) + kwargs.setdefault("gpu", False) # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). SEARCH_CHUNK = 500_000 From 33acbe6caad213a603acd544fef82928956be55b Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 06:16:00 +0000 Subject: [PATCH 045/129] [review] SID: log the FAISS fit device (CPU/GPU) Announce CPU vs GPU + N/D at the start of train_offline so the CPU default isn't silent (configs that don't set faiss_kmeans_kwargs.gpu now fit on CPU). Gated by verbose (on_train_end passes verbose=True). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/residual_kmeans_quantizer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index d3ebf82c2..21af2f9af 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -192,6 +192,15 @@ def train_offline( # faiss_kmeans_kwargs (e.g. ``True`` for all GPUs) to opt into GPU. kwargs = dict(self.faiss_kmeans_kwargs) kwargs.setdefault("gpu", False) + if verbose: + logger.info( + "[ResidualKMeansQuantizer] fitting %d-layer codebook on %s " + "(N=%d, D=%d); set faiss_kmeans_kwargs.gpu to change.", + self.n_layers, + "GPU" if kwargs["gpu"] else "CPU", + N, + self.embed_dim, + ) # Chunk index.search to cap peak memory (~1 GB at 500K × 512 × 4B). SEARCH_CHUNK = 500_000 From 23c552cf06981f730aa13b1174410d50ddeb79e7 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 06:17:31 +0000 Subject: [PATCH 046/129] [chore] bump version to 1.2.18 Merge upstream/master (1.2.17, incl. #540 DlrmHSTU fix) and bump. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tzrec/version.py b/tzrec/version.py index c0c16b619..52c53fa4f 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.17" +__version__ = "1.2.18" From 3261c2ce04dc36ef0e2d462c063f7532cd9c19d8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 07:41:33 +0000 Subject: [PATCH 047/129] [review] SID: address 23c552c review (test timeout, N>=K assert, cap test, doc) - test_on_train_end_ddp: route both DDP tests through a shared _run_dist_workers(... timeout=120) so a success-path deadlock (e.g. a dropped barrier) fails CI instead of hanging. (#1) - train_offline: assert N >= max(n_embed_list) so a too-small corpus fails loudly instead of faiss silently fitting a degenerate codebook. (#2) - add test_sample_cap_from_train_sample_size covering the explicit train_sample_size branch + per-rank ceil-div across world_size. (#3) - update_metric docstring: note mse/rel_loss are meaningful only with normalize_residuals=False. (#4) Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 7 ++ tzrec/models/sid_rqkmeans_test.py | 96 ++++++++++++------- .../modules/sid/residual_kmeans_quantizer.py | 7 ++ 3 files changed, 74 insertions(+), 36 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 66f87cdd2..477580e4b 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -241,6 +241,13 @@ def update_metric( ) -> None: """Update metric state. + Note: ``mse``/``rel_loss`` compare ``input_embedding`` against the + centroid-sum reconstruction. They are meaningful reconstruction + metrics only with ``normalize_residuals=False`` (the default); with + normalization the centroids live on the rescaled-residual scale, so + the two quantities don't share a scale (same caveat the train_offline + per-layer log carries). + Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index e54e44f53..76a1bda0e 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -43,7 +43,12 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: def _build_model( - input_dim=32, n_layers=2, niter=5, codebook=None, normalize_residuals=False + input_dim=32, + n_layers=2, + niter=5, + codebook=None, + normalize_residuals=False, + train_sample_size=0, ) -> SidRqkmeans: """Build a SidRqkmeans configured for offline FAISS fit. @@ -63,6 +68,7 @@ def _build_model( normalize_residuals=normalize_residuals, faiss_kmeans_kwargs=faiss_kwargs, embedding_feature_name="item_emb", + train_sample_size=train_sample_size, ) return SidRqkmeans( model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), @@ -81,9 +87,17 @@ def _create_model( niter=5, codebook=None, normalize_residuals=False, + train_sample_size=0, ): """Create a SidRqkmeans on CPU with params initialized.""" - model = _build_model(input_dim, n_layers, niter, codebook, normalize_residuals) + model = _build_model( + input_dim, + n_layers, + niter, + codebook, + normalize_residuals, + train_sample_size, + ) init_parameters(model, device=torch.device("cpu")) return model @@ -96,6 +110,23 @@ def test_proto_parse(self) -> None: self.assertEqual(model._n_seen, 0) self.assertIsNone(model._reservoir) + def test_sample_cap_from_train_sample_size(self) -> None: + """Explicit train_sample_size drives the per-rank cap (ceil-div).""" + from unittest import mock + + # Single process (world_size=1): cap == train_sample_size. + model = self._create_model(train_sample_size=900) + self.assertEqual(model._sample_cap, 900) + + # Per-rank ceil-div across world_size (patch dist + recompute the cap). + for world_size, expected in [(4, 225), (7, 129), (1000, 1)]: + with ( + mock.patch.object(dist, "is_initialized", return_value=True), + mock.patch.object(dist, "get_world_size", return_value=world_size), + ): + model._init_reservoir() + self.assertEqual(model._sample_cap, expected) + def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 @@ -461,50 +492,43 @@ def _boom(*args, **kwargs): ) +def _run_dist_workers(worker, world_size: int, timeout: int = 120) -> None: + """Spawn ``world_size`` procs running ``worker(rank, world_size, port)``. + + Joins with a timeout so a deadlock (e.g. a dropped barrier / reordered + broadcast) fails the test instead of hanging CI, and raises on a hung or + nonzero-exit worker. + """ + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(world_size): + p = ctx.Process(target=worker, args=(rank, world_size, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join(timeout=timeout) + if p.is_alive(): + p.terminate() + raise RuntimeError(f"worker-{i} deadlocked (timed out after {timeout}s).") + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + class SidRqkmeansDistTest(unittest.TestCase): """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" def test_on_train_end_ddp(self) -> None: - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(WORLD_SIZE): - p = ctx.Process(target=_on_train_end_worker, args=(rank, WORLD_SIZE, port)) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join() - if p.exitcode != 0: - raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + _run_dist_workers(_on_train_end_worker, WORLD_SIZE) def test_on_train_end_ddp_rank0_failure(self) -> None: """A rank0-only fit failure raises on every rank — never deadlocks. Guards the status-flag-before-centroid-broadcast ordering: a regression - that reordered/dropped it would hang here. ``join(timeout=...)`` turns a - reintroduced deadlock into a CI failure instead of a hung job. + that reordered/dropped it would hang, which the join timeout turns into + a CI failure instead of a hung job. """ - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(WORLD_SIZE): - p = ctx.Process( - target=_on_train_end_fail_worker, args=(rank, WORLD_SIZE, port) - ) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join(timeout=120) - if p.is_alive(): - p.terminate() - raise RuntimeError( - f"worker-{i} deadlocked on a rank0 fit failure (timed out)." - ) - if p.exitcode != 0: - raise RuntimeError( - f"worker-{i} did not raise the coordinated error " - f"(exitcode={p.exitcode})." - ) + _run_dist_workers(_on_train_end_fail_worker, WORLD_SIZE) if __name__ == "__main__": diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 21af2f9af..a2648d2b8 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -178,6 +178,13 @@ def train_offline( ) x = inputs.detach().to("cpu", torch.float32).contiguous().clone() N = x.shape[0] + # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not + # errors) when N < K and returns a degenerate codebook, which the + # all-zero poison guard in KMeansLayer would not catch. + max_k = max(self.n_embed_list) + assert N >= max_k, ( + f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" + ) out = torch.zeros_like(x) # Original input, kept only for the log: the per-layer diagnostic is the # cumulative recon error of x0 by the centroid sum (what update_metric From 39017abf87849b6863d4b02b6bfd41f935a3d3ef Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 08:20:49 +0000 Subject: [PATCH 048/129] [review] checkpoint_util: force only overrides the dedupe Drop the redundant `or force` in `want` (the only caller pairs force with final=True, so final already sets want). `force` now purely bypasses the per-step dedupe, matching its docstring; behavior is identical for the on_train_end tail-save caller. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index 5bafd825c..78555a550 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -419,16 +419,17 @@ def maybe_save( data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. final: force a save (still subject to the dedupe), e.g. at train end. - force: save even if this step was already saved (bypasses the - per-step dedupe), e.g. when end-of-train work mutated the model - state at the final step (see ``on_train_end``). + force: when a save is already requested (e.g. ``final``), bypass the + per-step dedupe so it fires even if this step was already saved + — e.g. when end-of-train work mutated the model state at the + final step (see ``on_train_end``). No effect on its own. Returns: True if a checkpoint was saved. """ data_ts = self._reconcile_event_time(data_timestamp) - want = final or force + want = final if self._save_steps > 0 and step > 0 and step % self._save_steps == 0: want = True if ( From 5afbd5ed23437f0360900f74e6f99a3326f1576a Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 08:42:07 +0000 Subject: [PATCH 049/129] [review] checkpoint maybe_save: clarify final vs force docstrings Reword the `final`/`force` param docs to remove the verbal collision (both previously described as "force a save"). `final` sets `want`; `force` only relaxes the per-step dedupe and is a no-op on its own. Docstring-only; no behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index 78555a550..612cf023d 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -418,11 +418,15 @@ def maybe_save( epoch: current epoch; enables the epoch trigger when not None. data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. - final: force a save (still subject to the dedupe), e.g. at train end. - force: when a save is already requested (e.g. ``final``), bypass the - per-step dedupe so it fires even if this step was already saved - — e.g. when end-of-train work mutated the model state at the - final step (see ``on_train_end``). No effect on its own. + final: request a save unconditionally (still subject to the dedupe), + e.g. at train end. This sets ``want``; it does not bypass the + per-step dedupe — that is what ``force`` is for. + force: bypass the per-step dedupe so a wanted save fires even if this + step was already saved — e.g. when end-of-train work mutated the + model state at the already-saved final step (see ``on_train_end``). + Orthogonal to ``final``: ``force`` only relaxes the dedupe and has + no effect on its own (it still needs ``want``, which ``final`` or a + cadence trigger supplies). Returns: True if a checkpoint was saved. From 415b8a38dcac74d428abee3e8045498e59c51699 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:01:39 +0000 Subject: [PATCH 050/129] [refactor] SidRqkmeans: single-process only; raise under DDP Drop the DDP path in on_train_end (gather_object -> rank0 FAISS fit -> status/centroid broadcast -> barrier). SidRqkmeans now supports single-process training only: on_train_end raises RuntimeError when world_size > 1, and fits the codebook on the local reservoir otherwise. Simplify _init_reservoir accordingly (no per-rank cap split). Replace the multi-process DDP tests (gather/broadcast/rank0-failure) with a guard test asserting on_train_end raises under world_size>1; trim now-unused imports. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 99 +++++------------- tzrec/models/sid_rqkmeans_test.py | 162 ++++-------------------------- 2 files changed, 44 insertions(+), 217 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 477580e4b..8375ff981 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -97,18 +97,15 @@ def __init__( def _init_reservoir(self) -> None: """Set up the bounded host reservoir for the end-of-loop FAISS fit. - Per-rank cap: target ``train_sample_size`` when set (>0), else the - points the FAISS fit subsamples to - (``ResidualKMeansQuantizer.default_fit_sample_size``), split across - ranks — rather than buffer the whole corpus. + Caps at ``train_sample_size`` when set (>0), else the points the FAISS + fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``) + — rather than buffer the whole corpus. Single-process only (see the + world_size guard in :meth:`on_train_end`), so no per-rank split. """ target = self._model_config.train_sample_size - global_target = ( - target if target > 0 else self._quantizer.default_fit_sample_size() + self._sample_cap = max( + 1, target if target > 0 else self._quantizer.default_fit_sample_size() ) - world_size = dist.get_world_size() if dist.is_initialized() else 1 - # ceil div: round up so the per-rank caps together cover global_target. - self._sample_cap = max(1, -(-global_target // world_size)) # Allocated lazily on the first batch. _n_filled = used slots; # _n_seen = running count for the accept prob. @@ -272,79 +269,35 @@ def on_train_end(self) -> bool: """Fit the FAISS codebook once, after the train_eval loop exits. Overrides :meth:`BaseModel.on_train_end` (called unconditionally by - ``tzrec.main``). DDP: every rank gather_objects its reservoir to rank0, - which fits and broadcasts the centroids back. + ``tzrec.main``). Single-process only: the fit runs on one process over + its local reservoir, with no cross-rank gather/broadcast. - An empty reservoir only happens for a pathologically tiny corpus - (rebalance splits rows across ``num_workers * world_size``); it then - fails fast via the fit-status broadcast rather than hanging. + An empty reservoir only happens for a pathologically tiny corpus; the + fit is then skipped and ``False`` returned. Returns: is_ckpt_after_train (bool): ``True`` if the codebook was fitted - (centroids changed → force a final checkpoint). Only the - single-process path can return ``False`` (empty reservoir, fit - skipped); the DDP path either returns ``True`` or raises (an empty - gather makes rank0's fit fail, which the status broadcast turns - into a coordinated ``RuntimeError``). + (centroids changed → force a final checkpoint), ``False`` if the + fit was skipped (empty reservoir). + + Raises: + RuntimeError: if launched under distributed training + (``world_size > 1``). SidRqkmeans is single-process only. """ - is_ddp = ( - dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 - ) + if ( + dist.is_available() + and dist.is_initialized() + and dist.get_world_size() > 1 + ): + raise RuntimeError( + "SidRqkmeans supports single-process training only " + f"(world_size=1); got world_size={dist.get_world_size()}. " + "Launch with --nproc-per-node=1." + ) local = self._reservoir_sample() self._reset_reservoir() - if is_ddp: - # Each rank ships its (capped) reservoir to rank0, which fits. - rank = dist.get_rank() - gathered: Optional[List[Optional[torch.Tensor]]] = ( - [None] * dist.get_world_size() if rank == 0 else None - ) - dist.gather_object(local, gathered, dst=0) - del local - fit_ok = True - if rank == 0: - assert gathered is not None - try: - full = torch.cat([g for g in gathered if g is not None], dim=0) - del gathered - logger.info( - "[SidRqkmeans.on_train_end] rank0 fitting FAISS " - "on %d samples (D=%d)." % (full.shape[0], full.shape[1]) - ) - self._quantizer.train_offline(full, verbose=True) - del full - except Exception: # noqa: BLE001 - # Don't raise yet — peers would hang on the broadcast below. - # Signal failure via the status flag so all ranks raise. - # logger.exception keeps the traceback so the rank0-only - # failure is diagnosable from the log. - fit_ok = False - logger.exception( - "[SidRqkmeans.on_train_end] rank0 FAISS fit failed" - ) - # Broadcast rank0's status (int, not bool — see NCCL note below) so - # a rank0-only failure makes all ranks raise instead of deadlocking. - status = torch.tensor( - [1 if fit_ok else 0], - device=self._quantizer.layers[0].centroids.device, - ) - dist.broadcast(status, src=0) - if int(status.item()) == 0: - raise RuntimeError( - "[SidRqkmeans.on_train_end] FAISS fit failed on rank0; " - "see rank0 logs for the underlying error." - ) - # Broadcast centroids; set the init flag locally (avoids - # broadcasting a bool buffer — NCCL bool support is inconsistent). - # All ranks are in lockstep, so a local mark_initialized_() agrees. - for layer in self._quantizer.layers: - dist.broadcast(layer.centroids, src=0) - layer.mark_initialized_() - dist.barrier() - return True - - # Single-process: guard an empty reservoir with a plain local check. if local.shape[0] == 0: logger.warning( "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 76a1bda0e..00a320046 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -9,23 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest import torch import torch.distributed as dist -import torch.multiprocessing as mp from torchrec import KeyedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.models.sid_rqkmeans import SidRqkmeans from tzrec.protos import model_pb2 from tzrec.protos.models import sid_model_pb2 -from tzrec.utils import misc_util from tzrec.utils.state_dict_util import init_parameters -WORLD_SIZE = 2 - def _batch_from_rows(rows: torch.Tensor) -> Batch: """Wrap explicit ``item_emb`` rows in a minimal Batch.""" @@ -52,8 +47,6 @@ def _build_model( ) -> SidRqkmeans: """Build a SidRqkmeans configured for offline FAISS fit. - Module-level (not a method) so the spawned DDP workers below can build - the same model; callers move it to a device / init params as needed. SID models read the item-embedding dense feature directly from the batch and do not consume feature_groups, so none is set. """ @@ -111,21 +104,14 @@ def test_proto_parse(self) -> None: self.assertIsNone(model._reservoir) def test_sample_cap_from_train_sample_size(self) -> None: - """Explicit train_sample_size drives the per-rank cap (ceil-div).""" - from unittest import mock - - # Single process (world_size=1): cap == train_sample_size. + """train_sample_size (when set) drives the reservoir cap directly.""" + # Explicit train_sample_size: cap == train_sample_size. model = self._create_model(train_sample_size=900) self.assertEqual(model._sample_cap, 900) - # Per-rank ceil-div across world_size (patch dist + recompute the cap). - for world_size, expected in [(4, 225), (7, 129), (1000, 1)]: - with ( - mock.patch.object(dist, "is_initialized", return_value=True), - mock.patch.object(dist, "get_world_size", return_value=world_size), - ): - model._init_reservoir() - self.assertEqual(model._sample_cap, expected) + # Default (train_sample_size=0): cap == the FAISS fit's subsample size. + model = self._create_model() + self.assertEqual(model._sample_cap, model._quantizer.default_fit_sample_size()) def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" @@ -352,6 +338,19 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: # No fit happened, so no tail checkpoint is requested. self.assertFalse(model.on_train_end()) # should not raise + def test_on_train_end_raises_under_ddp(self) -> None: + """SidRqkmeans is single-process only: world_size>1 must raise.""" + from unittest import mock + + model = self._create_model() + with ( + mock.patch.object(dist, "is_available", return_value=True), + mock.patch.object(dist, "is_initialized", return_value=True), + mock.patch.object(dist, "get_world_size", return_value=2), + self.assertRaisesRegex(RuntimeError, "single-process"), + ): + model.on_train_end() + def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. @@ -406,130 +405,5 @@ def test_mid_fit_checkpoint_rejected_on_load(self) -> None: fresh.load_state_dict(sd) -# -------------------------------------------------------------------------- -# Distributed (multi-process) test for the DDP on_train_end path: the -# cross-rank gather_object -> FAISS fit -> broadcast sequence the in-process -# tests above cannot reach. NCCL on GPU when >=2 devices, else gloo/CPU. -# -------------------------------------------------------------------------- -def _init_dist(rank: int, world_size: int, port: int) -> torch.device: - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size - if use_cuda: - torch.cuda.set_device(rank) - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - return torch.device(f"cuda:{rank}") - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) - return torch.device("cpu") - - -def _on_train_end_worker(rank: int, world_size: int, port: int) -> None: - device = _init_dist(rank, world_size, port) - input_dim, n_layers, k = 16, 2, 16 - model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) - model.train() - - torch.manual_seed(100 + rank) - for _ in range(6): - model.predict(_make_batch(32, input_dim, device)) - assert model._n_seen == 6 * 32, f"rank{rank}: reservoir not filled" - - # gather_object -> rank0 FAISS fit -> broadcast centroids + fill flag. - # Every rank fitted/received the codebook, so each requests a tail ckpt. - assert model.on_train_end(), f"rank{rank}: on_train_end should request ckpt" - - for layer in model._quantizer.layers: - assert bool(layer._is_initialized.item()), f"rank{rank}: layer uninit" - assert layer.centroids.abs().sum().item() > 0.0, f"rank{rank}: zero centroids" - # Centroids were broadcast from rank0 -> must be bit-identical across ranks. - for layer in model._quantizer.layers: - cmin, cmax = layer.centroids.clone(), layer.centroids.clone() - dist.all_reduce(cmin, op=dist.ReduceOp.MIN) - dist.all_reduce(cmax, op=dist.ReduceOp.MAX) - assert torch.allclose(cmin, cmax), f"rank{rank}: centroids differ across ranks" - - model.eval() - codes = model.predict(_make_batch(8, input_dim, device))["codes"] - assert codes.shape == (8, n_layers), f"rank{rank}: bad codes shape {codes.shape}" - assert (codes >= 0).all() and (codes < k).all(), f"rank{rank}: codes out of range" - dist.destroy_process_group() - - -def _on_train_end_fail_worker(rank: int, world_size: int, port: int) -> None: - """Worker that forces rank0's FAISS fit to fail. - - Every rank must then raise the coordinated ``RuntimeError`` (driven by the - fit-status broadcast) instead of deadlocking on the centroid broadcast. A - worker returns 0 only if it caught that expected error. - """ - device = _init_dist(rank, world_size, port) - input_dim, n_layers, k = 16, 2, 16 - model = _build_model(input_dim, n_layers, codebook=[k] * n_layers).to(device) - model.train() - for _ in range(6): - model.predict(_make_batch(32, input_dim, device)) - - # Force the rank0-only fit to raise (no faiss needed: only rank0 fits, and - # we replace its fit). The status flag must turn this into an all-ranks - # raise, not a hang. - if rank == 0: - - def _boom(*args, **kwargs): - raise RuntimeError("forced rank0 fit failure") - - model._quantizer.train_offline = _boom - - try: - model.on_train_end() - except RuntimeError: - dist.destroy_process_group() - return # expected: coordinated failure reached this rank - dist.destroy_process_group() - raise AssertionError( - f"rank{rank}: on_train_end did not raise on a rank0 fit failure" - ) - - -def _run_dist_workers(worker, world_size: int, timeout: int = 120) -> None: - """Spawn ``world_size`` procs running ``worker(rank, world_size, port)``. - - Joins with a timeout so a deadlock (e.g. a dropped barrier / reordered - broadcast) fails the test instead of hanging CI, and raises on a hung or - nonzero-exit worker. - """ - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(world_size): - p = ctx.Process(target=worker, args=(rank, world_size, port)) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join(timeout=timeout) - if p.is_alive(): - p.terminate() - raise RuntimeError(f"worker-{i} deadlocked (timed out after {timeout}s).") - if p.exitcode != 0: - raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") - - -class SidRqkmeansDistTest(unittest.TestCase): - """2-rank test for SidRqkmeans.on_train_end (gather -> fit -> broadcast).""" - - def test_on_train_end_ddp(self) -> None: - _run_dist_workers(_on_train_end_worker, WORLD_SIZE) - - def test_on_train_end_ddp_rank0_failure(self) -> None: - """A rank0-only fit failure raises on every rank — never deadlocks. - - Guards the status-flag-before-centroid-broadcast ordering: a regression - that reordered/dropped it would hang, which the join timeout turns into - a CI failure instead of a hung job. - """ - _run_dist_workers(_on_train_end_fail_worker, WORLD_SIZE) - - if __name__ == "__main__": unittest.main() From b27eb7b5b55e7dab7a7d8b014feb51df7567ca73 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:05:27 +0000 Subject: [PATCH 051/129] [refactor] SidRqkmeans: move DDP guard to __init__ (fail fast) Raise the single-process world_size>1 guard at construction instead of in on_train_end, so an accidental multi-rank launch fails immediately rather than after a full training pass. Update the guard test to assert __init__ raises. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 36 +++++++++++++++---------------- tzrec/models/sid_rqkmeans_test.py | 7 +++--- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 8375ff981..3c87cedae 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -72,6 +72,20 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) + # Single-process only: the FAISS fit runs on one process over its local + # reservoir, with no cross-rank gather/broadcast. Fail fast here rather + # than after a full (wasted) training pass. + if ( + dist.is_available() + and dist.is_initialized() + and dist.get_world_size() > 1 + ): + raise RuntimeError( + "SidRqkmeans supports single-process training only " + f"(world_size=1); got world_size={dist.get_world_size()}. " + "Launch with --nproc-per-node=1." + ) + cfg = self._model_config # SidRqkmeans proto message # config_to_kwargs yields Struct numbers as floats; coerce back to int. @@ -100,7 +114,7 @@ def _init_reservoir(self) -> None: Caps at ``train_sample_size`` when set (>0), else the points the FAISS fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``) — rather than buffer the whole corpus. Single-process only (see the - world_size guard in :meth:`on_train_end`), so no per-rank split. + world_size guard in ``__init__``), so no per-rank split. """ target = self._model_config.train_sample_size self._sample_cap = max( @@ -269,8 +283,9 @@ def on_train_end(self) -> bool: """Fit the FAISS codebook once, after the train_eval loop exits. Overrides :meth:`BaseModel.on_train_end` (called unconditionally by - ``tzrec.main``). Single-process only: the fit runs on one process over - its local reservoir, with no cross-rank gather/broadcast. + ``tzrec.main``). Single-process only (enforced by the world_size guard + in ``__init__``): the fit runs on one process over its local reservoir, + with no cross-rank gather/broadcast. An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped and ``False`` returned. @@ -279,22 +294,7 @@ def on_train_end(self) -> bool: is_ckpt_after_train (bool): ``True`` if the codebook was fitted (centroids changed → force a final checkpoint), ``False`` if the fit was skipped (empty reservoir). - - Raises: - RuntimeError: if launched under distributed training - (``world_size > 1``). SidRqkmeans is single-process only. """ - if ( - dist.is_available() - and dist.is_initialized() - and dist.get_world_size() > 1 - ): - raise RuntimeError( - "SidRqkmeans supports single-process training only " - f"(world_size=1); got world_size={dist.get_world_size()}. " - "Launch with --nproc-per-node=1." - ) - local = self._reservoir_sample() self._reset_reservoir() diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 00a320046..ef90d6032 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -338,18 +338,17 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: # No fit happened, so no tail checkpoint is requested. self.assertFalse(model.on_train_end()) # should not raise - def test_on_train_end_raises_under_ddp(self) -> None: - """SidRqkmeans is single-process only: world_size>1 must raise.""" + def test_init_raises_under_ddp(self) -> None: + """SidRqkmeans is single-process only: world_size>1 fails fast in init.""" from unittest import mock - model = self._create_model() with ( mock.patch.object(dist, "is_available", return_value=True), mock.patch.object(dist, "is_initialized", return_value=True), mock.patch.object(dist, "get_world_size", return_value=2), self.assertRaisesRegex(RuntimeError, "single-process"), ): - model.on_train_end() + self._create_model() def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. From 6f7ae1dee07ad6ba6d94239e3a6519d2a08b4a35 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:08:47 +0000 Subject: [PATCH 052/129] [simplify] SidRqkmeans: drop dead max(1,...) cap clamp; fold test _build_model MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _init_reservoir: both cap branches are always >= 1 now that the per-rank ceil-div is gone, so the max(1, ...) clamp is dead — drop it. Test: _build_model was module-level only to serve the (now-deleted) DDP worker processes; fold it into _create_model, its sole remaining caller. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 4 +- tzrec/models/sid_rqkmeans_test.py | 63 +++++++++++-------------------- 2 files changed, 24 insertions(+), 43 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3c87cedae..65c5aab46 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -117,8 +117,8 @@ def _init_reservoir(self) -> None: world_size guard in ``__init__``), so no per-rank split. """ target = self._model_config.train_sample_size - self._sample_cap = max( - 1, target if target > 0 else self._quantizer.default_fit_sample_size() + self._sample_cap = ( + target if target > 0 else self._quantizer.default_fit_sample_size() ) # Allocated lazily on the first batch. _n_filled = used slots; diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index ef90d6032..3b7aded5b 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -37,39 +37,6 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: return _batch_from_rows(torch.randn(batch_size, input_dim, device=device)) -def _build_model( - input_dim=32, - n_layers=2, - niter=5, - codebook=None, - normalize_residuals=False, - train_sample_size=0, -) -> SidRqkmeans: - """Build a SidRqkmeans configured for offline FAISS fit. - - SID models read the item-embedding dense feature directly from the batch - and do not consume feature_groups, so none is set. - """ - from google.protobuf.struct_pb2 import Struct - - n_embed_list = codebook if codebook is not None else [16] * n_layers - faiss_kwargs = Struct() - faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) - cfg = sid_model_pb2.SidRqkmeans( - input_dim=input_dim, - codebook=n_embed_list, - normalize_residuals=normalize_residuals, - faiss_kmeans_kwargs=faiss_kwargs, - embedding_feature_name="item_emb", - train_sample_size=train_sample_size, - ) - return SidRqkmeans( - model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), - features=[], - labels=[], - ) - - class SidRqkmeansOfflineTest(unittest.TestCase): """Single-process tests for SidRqkmeans (FAISS-only).""" @@ -82,14 +49,28 @@ def _create_model( normalize_residuals=False, train_sample_size=0, ): - """Create a SidRqkmeans on CPU with params initialized.""" - model = _build_model( - input_dim, - n_layers, - niter, - codebook, - normalize_residuals, - train_sample_size, + """Build a SidRqkmeans on CPU with params initialized. + + SID models read the item-embedding dense feature directly from the + batch and do not consume feature_groups, so none is set. + """ + from google.protobuf.struct_pb2 import Struct + + n_embed_list = codebook if codebook is not None else [16] * n_layers + faiss_kwargs = Struct() + faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + cfg = sid_model_pb2.SidRqkmeans( + input_dim=input_dim, + codebook=n_embed_list, + normalize_residuals=normalize_residuals, + faiss_kmeans_kwargs=faiss_kwargs, + embedding_feature_name="item_emb", + train_sample_size=train_sample_size, + ) + model = SidRqkmeans( + model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), + features=[], + labels=[], ) init_parameters(model, device=torch.device("cpu")) return model From 5827d5b5b9fee48de07d9ce7a281d048d35b709d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:12:30 +0000 Subject: [PATCH 053/129] [style] ruff-format the __init__ DDP guard (collapse to one line) The world_size guard fit within the line limit, so ruff format collapses the parenthesized multi-line `if` to a single line. ruff check passed but ruff-format (pre-commit / codestyle CI) did not. No logic change. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 65c5aab46..12641974a 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -75,11 +75,7 @@ def __init__( # Single-process only: the FAISS fit runs on one process over its local # reservoir, with no cross-rank gather/broadcast. Fail fast here rather # than after a full (wasted) training pass. - if ( - dist.is_available() - and dist.is_initialized() - and dist.get_world_size() > 1 - ): + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: raise RuntimeError( "SidRqkmeans supports single-process training only " f"(world_size=1); got world_size={dist.get_world_size()}. " From 4e2e87848f85112452014a63b6c69271388fecbd Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 12:59:41 +0000 Subject: [PATCH 054/129] =?UTF-8?q?[refactor]=20SidRqkmeans:=20CPU-only=20?= =?UTF-8?q?=E2=80=94=20raise=20on=20visible=20CUDA,=20drop=20device=20copi?= =?UTF-8?q?es?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SID RQ-KMeans is now CPU-only by decision. __init__ raises RuntimeError when torch.cuda.is_available() so all tensors (embeddings, reservoir, FAISS fit) stay on the host; run with CUDA_VISIBLE_DEVICES="". Remove the now-dead CPU<->GPU copies: - _reservoir_add: x is already on host, so .to("cpu", float32) becomes a plain float32 cast; drop idx.to(x.device). - train_offline: input is host float32 (.to("cpu") -> .to(float32)); drop centroids.cpu() and the explicit device="cpu" on search indices. - Drop the faiss-GPU passthrough: pop any "gpu" kwarg so a stale config / faiss-gpu build can't target an absent GPU; log line is CPU-only. Tests: setUp simulates a CPU-only host (GPU CI runners have CUDA, which would otherwise trip the new guard); add test_init_raises_on_gpu. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 21 +++++++--- tzrec/models/sid_rqkmeans_test.py | 19 ++++++++- .../modules/sid/residual_kmeans_quantizer.py | 41 ++++++++----------- 3 files changed, 50 insertions(+), 31 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 12641974a..c393b4943 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -72,6 +72,16 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) + # CPU-only: everything (embeddings, reservoir, FAISS fit) stays on the + # host, so there are no device copies on the train path. Refuse to run + # when CUDA is visible rather than silently shuttling tensors to/from a + # GPU; launch with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host). + if torch.cuda.is_available(): + raise RuntimeError( + "SidRqkmeans is CPU-only, but a CUDA device is visible. " + 'Run with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host).' + ) + # Single-process only: the FAISS fit runs on one process over its local # reservoir, with no cross-rank gather/broadcast. Fail fast here rather # than after a full (wasted) training pass. @@ -138,11 +148,12 @@ def _reservoir_add(self, x: torch.Tensor) -> None: if self._reservoir is None: self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) - # Phase 1: fill empty slots first. Copy only the rows we keep to host. + # Phase 1: fill empty slots first. x is already on the host (CPU-only + # model), so this is a dtype cast into the reservoir, not a device copy. if self._n_filled < cap: take = min(x.shape[0], cap - self._n_filled) self._reservoir[self._n_filled : self._n_filled + take] = x[:take].to( - "cpu", dtype=torch.float32 + torch.float32 ) self._n_filled += take self._n_seen += take @@ -151,9 +162,7 @@ def _reservoir_add(self, x: torch.Tensor) -> None: return # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random - # slot. The accept decision needs only counts, so compute it on host and - # copy ONLY accepted rows (in steady state, almost none) — avoiding the - # whole-batch GPU->CPU copy. float64 keeps n_seen+j+1 exact past 2**24. + # slot. float64 keeps n_seen+j+1 exact past 2**24. r = x.shape[0] pos = self._n_seen + torch.arange(r) accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) @@ -161,7 +170,7 @@ def _reservoir_add(self, x: torch.Tensor) -> None: if idx.numel() > 0: slots = torch.randint(0, cap, (idx.numel(),)) # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. - self._reservoir[slots] = x[idx.to(x.device)].to("cpu", dtype=torch.float32) + self._reservoir[slots] = x[idx].to(torch.float32) self._n_seen += r def _reservoir_sample(self) -> torch.Tensor: diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 3b7aded5b..f29b9455a 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -10,6 +10,7 @@ # limitations under the License. import unittest +from unittest import mock import torch import torch.distributed as dist @@ -40,6 +41,14 @@ def _make_batch(batch_size: int, input_dim: int, device: str = "cpu") -> Batch: class SidRqkmeansOfflineTest(unittest.TestCase): """Single-process tests for SidRqkmeans (FAISS-only).""" + def setUp(self) -> None: + # SidRqkmeans is CPU-only and refuses to init when CUDA is visible. The + # GPU CI runners have CUDA, so simulate a CPU-only host for every + # construction-based test. (test_init_raises_on_gpu overrides this.) + patcher = mock.patch.object(torch.cuda, "is_available", return_value=False) + patcher.start() + self.addCleanup(patcher.stop) + def _create_model( self, input_dim=32, @@ -321,8 +330,6 @@ def test_on_train_end_noop_on_empty_buffer(self) -> None: def test_init_raises_under_ddp(self) -> None: """SidRqkmeans is single-process only: world_size>1 fails fast in init.""" - from unittest import mock - with ( mock.patch.object(dist, "is_available", return_value=True), mock.patch.object(dist, "is_initialized", return_value=True), @@ -331,6 +338,14 @@ def test_init_raises_under_ddp(self) -> None: ): self._create_model() + def test_init_raises_on_gpu(self) -> None: + """SidRqkmeans is CPU-only: a visible CUDA device fails fast in init.""" + with ( + mock.patch.object(torch.cuda, "is_available", return_value=True), + self.assertRaisesRegex(RuntimeError, "CPU-only"), + ): + self._create_model() + def test_post_fit_checkpoint_round_trips(self) -> None: """Fit → save state_dict → load into fresh instance → predict. diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index a2648d2b8..971ef9e3b 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -55,7 +55,8 @@ class ResidualKMeansQuantizer(ResidualQuantizer): residuals with no per-layer normalization). faiss_kmeans_kwargs (Dict|None): extra kwargs forwarded to ``faiss.Kmeans(D, K, **kwargs)`` (e.g. {'niter': 20, - 'gpu': True, 'verbose': True, 'spherical': False}). + 'verbose': True, 'spherical': False}). A ``gpu`` key is ignored — + the fit is CPU-only. """ def __init__( @@ -160,23 +161,21 @@ def train_offline( ) -> None: """Train the multi-layer codebook via offline FAISS K-Means. - The residual matrix stays a host (CPU) tensor. With a faiss-gpu build, - ``faiss.Kmeans`` runs the K-Means training (over its internally - subsampled set) on the GPU; the post-fit ``index.search`` assignment - still streams all N rows through in ``SEARCH_CHUNK``-sized chunks, so we - never hold the full (N, D) on the device. faiss-cpu runs the same path - on CPU. + CPU-only: ``inputs`` is already a host tensor (SidRqkmeans refuses to + run when CUDA is visible) and the FAISS fit runs on CPU. The post-fit + ``index.search`` assignment streams all N rows through in + ``SEARCH_CHUNK``-sized chunks to cap peak memory. Args: - inputs (Tensor): embedding matrix (N, D). Copied once to an owned - CPU float32 tensor; not mutated. + inputs (Tensor): embedding matrix (N, D) on CPU. Copied once to an + owned float32 tensor; not mutated. verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Own a contiguous CPU float32 copy to update in place as the residual. + # Own a contiguous float32 copy to update in place as the residual. assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - x = inputs.detach().to("cpu", torch.float32).contiguous().clone() + x = inputs.detach().to(torch.float32).contiguous().clone() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the @@ -191,20 +190,16 @@ def train_offline( # reports). ``out + x`` would equal it only without normalization. x0 = x.clone() if verbose else None - # Default to a CPU fit. faiss reads ``gpu`` as a GPU *count*, not a - # device index (and ``1 == True`` collapses to all GPUs), so it cannot - # pin this rank0-only fit to a single device without sharding faiss - # memory onto the other ranks' GPUs. The fit is a bounded one-shot over - # the reservoir subsample, so CPU is cheap; set ``gpu`` explicitly in - # faiss_kmeans_kwargs (e.g. ``True`` for all GPUs) to opt into GPU. + # CPU-only fit: SidRqkmeans refuses to initialize when CUDA is visible, + # so the codebook is always built on CPU. Drop any stale ``gpu`` request + # from the config so a faiss-gpu build can't try to use an absent GPU. kwargs = dict(self.faiss_kmeans_kwargs) - kwargs.setdefault("gpu", False) + kwargs.pop("gpu", None) if verbose: logger.info( - "[ResidualKMeansQuantizer] fitting %d-layer codebook on %s " - "(N=%d, D=%d); set faiss_kmeans_kwargs.gpu to change.", + "[ResidualKMeansQuantizer] fitting %d-layer codebook on CPU " + "(N=%d, D=%d).", self.n_layers, - "GPU" if kwargs["gpu"] else "CPU", N, self.embed_dim, ) @@ -222,12 +217,12 @@ def train_offline( self.embed_dim, self.n_embed_list[layer_idx], **kwargs ) kmeans.train(x) - centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32).cpu() + centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32) for start in range(0, N, SEARCH_CHUNK): end = min(start + SEARCH_CHUNK, N) _, idx = kmeans.index.search(x[start:end], 1) - idx = torch.as_tensor(idx, device="cpu").reshape(-1).long() + idx = torch.as_tensor(idx).reshape(-1).long() q = centroids[idx] # (chunk, D) out[start:end] += q x[start:end] -= q # residual From 4773e2a657ac4746bac8d80c05deafc0d8df578f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 9 Jun 2026 13:03:31 +0000 Subject: [PATCH 055/129] [simplify] train_offline: assert host input; single-copy float32 own MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Assert inputs is not CUDA: the quantizer is a standalone module that now assumes host tensors (SidRqkmeans enforces CPU-only at __init__); make the contract local so misuse fails here, not opaquely inside faiss. - Replace `.to(float32).contiguous().clone()` with `.to(dtype=float32, copy=True).contiguous()` — one guaranteed owning copy instead of a chain that could double-copy on a non-contiguous/non-float32 input. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/residual_kmeans_quantizer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 971ef9e3b..1bfe20267 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -171,11 +171,15 @@ def train_offline( owned float32 tensor; not mutated. verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Own a contiguous float32 copy to update in place as the residual. + # CPU-only: SidRqkmeans refuses to init when CUDA is visible, but this + # quantizer is a standalone module — assert the host-tensor contract it + # relies on so misuse fails here, not deep inside faiss. + assert not inputs.is_cuda, "train_offline is CPU-only; got a CUDA tensor" assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - x = inputs.detach().to(torch.float32).contiguous().clone() + # Own one contiguous float32 copy to update in place as the residual. + x = inputs.detach().to(dtype=torch.float32, copy=True).contiguous() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the From df83d070f4569c0dfa2e14543451843c790d41d4 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 02:34:01 +0000 Subject: [PATCH 056/129] [refactor] KMeansLayer.predict: use torch.cdist; drop _squared_euclidean_distance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review: replace the manual squared-L2 expansion with torch.cdist(batch, centroids).argmin(-1). argmin is invariant to the monotonic sqrt, so codes are identical for all non-degenerate inputs — verified bit-exact across a wide sweep (random shapes/dtypes incl. 1000x8192x128, all cdist compute_modes, large-magnitude/cancellation, and near-ties): 0 mismatches over ~140k rows. The only divergence is at *exact* equidistant ties (measure zero for real embeddings), where either centroid is equally near. Confirmed predict still scripts / FX-traces / torch.exports identically to eager. Removes the now-unused _squared_euclidean_distance helper + its unit test on this branch. NOTE: feat/sid_abstract's vector_quantize.py (RQ-VAE, PR3) also imports this helper — PR3 must migrate its l2 branch to cdist in the same series. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 29 ++++++++--------------------- tzrec/modules/sid/kmeans_test.py | 9 --------- 2 files changed, 8 insertions(+), 30 deletions(-) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index d6e34acd9..7e874a0c1 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -52,25 +52,6 @@ def recon_diagnostics( return mse, rel -@torch.no_grad() -def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Squared L2 distance between rows of ``x`` and ``y``. - - Args: - x (Tensor): data points, shape (N, D). - y (Tensor): centroids, shape (K, D). - - Returns: - Tensor: squared distances, shape (N, K). - - Kept branch-free (no data-dependent control flow on ``N``) so the - per-batch predict forward stays FX-traceable for torchrec inference. - """ - x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) - y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) - return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) - - class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. @@ -165,11 +146,17 @@ def _load_from_state_dict( def predict(self, batch: torch.Tensor) -> torch.Tensor: """Assign points to nearest centroid. + Uses ``torch.cdist`` (plain L2). argmin is invariant to the monotonic + sqrt, so the assignment is identical to squared-L2 for all + non-degenerate inputs (verified bit-exact across random / large- + magnitude / near-tie sweeps); only an exact equidistant tie — measure + zero for real embeddings — may resolve to a different, equally-near + centroid. + Args: batch (Tensor): data points, shape (B, D). Returns: Tensor: cluster indices, shape (B,). """ - dists = _squared_euclidean_distance(batch, self.centroids) - return torch.argmin(dists, dim=-1) + return torch.cdist(batch, self.centroids).argmin(dim=-1) diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index cb86a39d8..1b21604d3 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -15,7 +15,6 @@ from tzrec.modules.sid.kmeans import ( KMeansLayer, - _squared_euclidean_distance, recon_diagnostics, ) @@ -29,14 +28,6 @@ def test_recon_diagnostics_zero_on_identity(self) -> None: self.assertAlmostEqual(mse.item(), 0.0, places=6) self.assertAlmostEqual(rel.item(), 0.0, places=6) - def test_squared_euclidean_distance(self) -> None: - x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) - y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) - d = _squared_euclidean_distance(x, y) - self.assertEqual(d.shape, (2, 2)) - # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 - torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) - class KMeansLayerTest(unittest.TestCase): """Tests for the single KMeansLayer.""" From d037db7dbb6701a4a2d5f59989da170066c5fc20 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 02:56:23 +0000 Subject: [PATCH 057/129] [refactor] SidRqkmeans: drop input_embedding from predictions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review (#1): the reconstruction target is an input, not a model output, so don't thread it through predictions. update_metric now re-extracts the embedding from batch (mirrors SidRqvae.update_metric) and guards on "quantized", which is eval-only — so the reconstruction metric stays eval-only by construction. predict (eval) now exposes {codes, quantized}; the metric test passes the same batch through predict + update_metric so the re-extracted target matches the prediction. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 29 ++++++++++++++--------------- tzrec/models/sid_rqkmeans_test.py | 16 ++++++++++------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index c393b4943..7a065e471 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -218,7 +218,6 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: if self.is_eval: predictions["quantized"] = quantized - predictions["input_embedding"] = embedding return predictions @@ -257,28 +256,28 @@ def update_metric( ) -> None: """Update metric state. - Note: ``mse``/``rel_loss`` compare ``input_embedding`` against the - centroid-sum reconstruction. They are meaningful reconstruction - metrics only with ``normalize_residuals=False`` (the default); with - normalization the centroids live on the rescaled-residual scale, so - the two quantities don't share a scale (same caveat the train_offline - per-layer log carries). + The reconstruction target (the input embedding) is re-extracted from + ``batch`` rather than threaded through ``predictions`` — it is an input, + not a model output (mirrors ``SidRqvae.update_metric``). ``quantized`` is + present only in eval (see ``predict``), so this runs eval-only. + + Note: ``mse``/``rel_loss`` compare that embedding against the centroid-sum + reconstruction. They are meaningful reconstruction metrics only with + ``normalize_residuals=False`` (the default); with normalization the + centroids live on the rescaled-residual scale, so the two quantities + don't share a scale (same caveat the train_offline per-layer log carries). Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ - if "input_embedding" in predictions: - _, rel = recon_diagnostics( - predictions["input_embedding"], - predictions["quantized"], - ) + if "quantized" in predictions: + embedding = self._extract_feature(batch) + _, rel = recon_diagnostics(embedding, predictions["quantized"]) # mse aggregates (preds, target) itself; rel_loss has no # torchmetrics equivalent, so it stays a MeanMetric. - self._metric_modules["mse"].update( - predictions["quantized"], predictions["input_embedding"] - ) + self._metric_modules["mse"].update(predictions["quantized"], embedding) self._metric_modules["rel_loss"].update(rel) self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index f29b9455a..f0964a8d0 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -275,7 +275,7 @@ def test_normalize_residuals_end_to_end(self) -> None: self.assertTrue((codes >= 0).all() and (codes < 16).all()) def test_eval_and_inference_predict_contract(self) -> None: - """Eval exposes quantized/input_embedding; inference is codes-only.""" + """Eval exposes codes + quantized only; inference is codes-only.""" try: import faiss # noqa: F401 except ImportError: @@ -288,11 +288,12 @@ def test_eval_and_inference_predict_contract(self) -> None: model.predict(_make_batch(B, input_dim)) model.on_train_end() - # Eval mode: reconstruction outputs are present for update_metric. + # Eval mode: the centroid-sum reconstruction is exposed for + # update_metric; the input embedding is NOT threaded through + # predictions (it is re-extracted from the batch in update_metric). model.eval() eval_preds = model.predict(_make_batch(B, input_dim)) - self.assertIn("quantized", eval_preds) - self.assertIn("input_embedding", eval_preds) + self.assertEqual(set(eval_preds.keys()), {"codes", "quantized"}) # Inference (serving) mode: codes-only contract. model.set_is_inference(True) @@ -315,8 +316,11 @@ def test_eval_metric_path(self) -> None: model.init_metric() model.eval() - preds = model.predict(_make_batch(B, input_dim)) - model.update_metric(preds, _make_batch(B, input_dim)) + # Same batch through predict + update_metric: the reconstruction target + # is re-extracted from this batch, so it must match the predicted one. + batch = _make_batch(B, input_dim) + preds = model.predict(batch) + model.update_metric(preds, batch) metrics = model.compute_metric() for key in ("mse", "rel_loss", "unique_sid_ratio"): self.assertIn(key, metrics) From 88856f3adf44e809fcf1d9d24fc75d761b73703b Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:08:18 +0000 Subject: [PATCH 058/129] [simplify] trim SID docstrings (predict provenance; stale SidRqvae xref) - KMeansLayer.predict: collapse the 6-line cdist-vs-squared-L2 verification provenance to a one-line equivalence note. - SidRqkmeans.update_metric: trim the over-explained re-extraction paragraph and drop the ``SidRqvae.update_metric`` cross-ref (sid_rqvae.py is PR3, not present in this PR's merge target, so the symbol doesn't resolve). Docstring-only; no behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 5 ++--- tzrec/modules/sid/kmeans.py | 9 +++------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 7a065e471..35dbc1036 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -257,9 +257,8 @@ def update_metric( """Update metric state. The reconstruction target (the input embedding) is re-extracted from - ``batch`` rather than threaded through ``predictions`` — it is an input, - not a model output (mirrors ``SidRqvae.update_metric``). ``quantized`` is - present only in eval (see ``predict``), so this runs eval-only. + ``batch`` — it is an input, not a model output. ``quantized`` is present + only in eval (see ``predict``), so this runs eval-only. Note: ``mse``/``rel_loss`` compare that embedding against the centroid-sum reconstruction. They are meaningful reconstruction metrics only with diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 7e874a0c1..02cfc63d6 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -146,12 +146,9 @@ def _load_from_state_dict( def predict(self, batch: torch.Tensor) -> torch.Tensor: """Assign points to nearest centroid. - Uses ``torch.cdist`` (plain L2). argmin is invariant to the monotonic - sqrt, so the assignment is identical to squared-L2 for all - non-degenerate inputs (verified bit-exact across random / large- - magnitude / near-tie sweeps); only an exact equidistant tie — measure - zero for real embeddings — may resolve to a different, equally-near - centroid. + Uses ``torch.cdist`` (L2); argmin is invariant to the monotonic sqrt, + so assignments match squared-L2 except at exact equidistant ties + (measure zero for real embeddings), where either centroid is valid. Args: batch (Tensor): data points, shape (B, D). From 2fa312b99da7c115c7d44021e26075d88fb8f7df Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:27:13 +0000 Subject: [PATCH 059/129] [refactor] extract reservoir sampling into ReservoirSampler (kmeans.py) Move the Vitter-Algorithm-R reservoir out of SidRqkmeans into a standalone ReservoirSampler class in kmeans.py (the shared SID/kmeans utility module, already home to recon_diagnostics). SidRqkmeans now holds one ReservoirSampler(cap, dim) and calls add()/sample()/reset(); the four state fields and three private reservoir methods are gone. Reservoir-mechanics tests (caps_memory, phase2_replacement) move to kmeans_test.py against ReservoirSampler directly (no model needed), plus empty-sample and reset tests; model tests now poke the sampler via its n_seen/n_filled/capacity accessors. Behavior-preserving. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 88 ++++------------------------- tzrec/models/sid_rqkmeans_test.py | 80 ++++----------------------- tzrec/modules/sid/kmeans.py | 92 ++++++++++++++++++++++++++++++- tzrec/modules/sid/kmeans_test.py | 73 ++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 145 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 35dbc1036..3f3b6954f 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -27,7 +27,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import recon_diagnostics +from tzrec.modules.sid.kmeans import ReservoirSampler, recon_diagnostics from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -109,82 +109,18 @@ def __init__( faiss_kmeans_kwargs=self._faiss_kwargs, ) - self._init_reservoir() + # Bounded host reservoir for the end-of-loop FAISS fit: cap at + # ``train_sample_size`` when set (>0), else the points the FAISS fit + # subsamples to (``default_fit_sample_size``) — rather than buffer the + # whole corpus. Single-process only (see the world_size guard above), + # so no per-rank split. + target = self._model_config.train_sample_size + cap = target if target > 0 else self._quantizer.default_fit_sample_size() + self._reservoir = ReservoirSampler(cap, self._input_dim) # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. self._dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True) - def _init_reservoir(self) -> None: - """Set up the bounded host reservoir for the end-of-loop FAISS fit. - - Caps at ``train_sample_size`` when set (>0), else the points the FAISS - fit subsamples to (``ResidualKMeansQuantizer.default_fit_sample_size``) - — rather than buffer the whole corpus. Single-process only (see the - world_size guard in ``__init__``), so no per-rank split. - """ - target = self._model_config.train_sample_size - self._sample_cap = ( - target if target > 0 else self._quantizer.default_fit_sample_size() - ) - - # Allocated lazily on the first batch. _n_filled = used slots; - # _n_seen = running count for the accept prob. - self._reservoir: Optional[torch.Tensor] = None - self._n_filled = 0 - self._n_seen = 0 - - @torch.no_grad() - def _reservoir_add(self, x: torch.Tensor) -> None: - """Stream a batch into the reservoir (Vitter Algorithm R). - - Keeps a uniform ``_sample_cap`` sample of all embeddings seen, in - O(cap) host memory. - - Args: - x (Tensor): batch of embeddings, shape (B, D). - """ - x = x.detach() - cap = self._sample_cap - if self._reservoir is None: - self._reservoir = torch.empty(cap, x.shape[1], dtype=torch.float32) - - # Phase 1: fill empty slots first. x is already on the host (CPU-only - # model), so this is a dtype cast into the reservoir, not a device copy. - if self._n_filled < cap: - take = min(x.shape[0], cap - self._n_filled) - self._reservoir[self._n_filled : self._n_filled + take] = x[:take].to( - torch.float32 - ) - self._n_filled += take - self._n_seen += take - x = x[take:] - if x.shape[0] == 0: - return - - # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random - # slot. float64 keeps n_seen+j+1 exact past 2**24. - r = x.shape[0] - pos = self._n_seen + torch.arange(r) - accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) - idx = accept.nonzero(as_tuple=True)[0] - if idx.numel() > 0: - slots = torch.randint(0, cap, (idx.numel(),)) - # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. - self._reservoir[slots] = x[idx].to(torch.float32) - self._n_seen += r - - def _reservoir_sample(self) -> torch.Tensor: - """Return the filled portion of the reservoir, shape (n_filled, D).""" - if self._reservoir is None or self._n_filled == 0: - return torch.empty(0, self._input_dim, dtype=torch.float32) - return self._reservoir[: self._n_filled] - - def _reset_reservoir(self) -> None: - """Drop the reservoir after the FAISS fit to free host memory.""" - self._reservoir = None - self._n_filled = 0 - self._n_seen = 0 - def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. @@ -202,7 +138,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: # Training: just reservoir-sample for the end-of-loop FAISS fit and # return dummy codes — the codebook does not exist yet. if self.is_train: - self._reservoir_add(embedding) + self._reservoir.add(embedding) B = embedding.shape[0] return { "codes": torch.zeros( @@ -298,8 +234,8 @@ def on_train_end(self) -> bool: (centroids changed → force a final checkpoint), ``False`` if the fit was skipped (empty reservoir). """ - local = self._reservoir_sample() - self._reset_reservoir() + local = self._reservoir.sample() + self._reservoir.reset() if local.shape[0] == 0: logger.warning( diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index f0964a8d0..c312ebad3 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -90,18 +90,20 @@ def test_proto_parse(self) -> None: self.assertEqual(model._faiss_kwargs.get("niter"), 5) self.assertEqual(model._faiss_kwargs.get("seed"), 1234) self.assertFalse(model._faiss_kwargs.get("verbose")) - self.assertEqual(model._n_seen, 0) - self.assertIsNone(model._reservoir) + self.assertEqual(model._reservoir.n_seen, 0) + self.assertEqual(model._reservoir.n_filled, 0) def test_sample_cap_from_train_sample_size(self) -> None: """train_sample_size (when set) drives the reservoir cap directly.""" # Explicit train_sample_size: cap == train_sample_size. model = self._create_model(train_sample_size=900) - self.assertEqual(model._sample_cap, 900) + self.assertEqual(model._reservoir.capacity, 900) # Default (train_sample_size=0): cap == the FAISS fit's subsample size. model = self._create_model() - self.assertEqual(model._sample_cap, model._quantizer.default_fit_sample_size()) + self.assertEqual( + model._reservoir.capacity, model._quantizer.default_fit_sample_size() + ) def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" @@ -116,70 +118,12 @@ def test_predict_collects_buffer(self) -> None: # Reservoir holds all 4*B samples (well under the cap) and tracks # the running count. - self.assertEqual(model._n_seen, 4 * B) - self.assertEqual(model._n_filled, 4 * B) + self.assertEqual(model._reservoir.n_seen, 4 * B) + self.assertEqual(model._reservoir.n_filled, 4 * B) # FAISS not yet triggered: layers should be uninitialized for layer in model._quantizer.layers: self.assertFalse(layer.is_initialized) - def test_reservoir_caps_memory(self) -> None: - """Reservoir bounds the buffer at _sample_cap regardless of corpus.""" - B, input_dim = 16, 8 - model = self._create_model(input_dim=input_dim) - model._sample_cap = 10 # force a tiny cap - model._reset_reservoir() - model.train() - for _ in range(20): # 320 rows >> cap - model.predict(_make_batch(B, input_dim)) - self.assertEqual(model._n_seen, 20 * B) - self.assertEqual(model._n_filled, 10) - self.assertEqual(model._reservoir.shape, (10, input_dim)) - - def test_reservoir_phase2_replacement(self) -> None: - """Phase-2 replacement keeps a valid reservoir of real, in-range rows. - - Feeds identifiable rows (each row's value == its global stream index), - then asserts every reservoir slot still holds an intact fed row, all - indices are in range, and replacement past the initial fill actually - happened — exercising the accept-prob / slot-write logic that the - count/shape-only ``test_reservoir_caps_memory`` cannot. - """ - torch.manual_seed(0) - input_dim, cap, B, n_batches = 4, 8, 4, 50 - model = self._create_model(input_dim=input_dim) - model._sample_cap = cap - model._reset_reservoir() - model.train() - - gidx = 0 - for _ in range(n_batches): - rows = ( - torch.arange(gidx, gidx + B, dtype=torch.float32) - .unsqueeze(1) - .expand(B, input_dim) - .contiguous() - ) - gidx += B - model.predict(_batch_from_rows(rows)) - - total = B * n_batches - self.assertEqual(model._n_seen, total) - self.assertEqual(model._n_filled, cap) - - res = model._reservoir - idx = res[:, 0].round().long() - # Each stored row is an intact fed row (all columns equal its index), - # never zeros/garbage. - self.assertTrue( - torch.equal(res, idx.unsqueeze(1).float().expand_as(res)), - "reservoir holds corrupted (non-fed) rows", - ) - # All indices are valid stream positions. - self.assertTrue((idx >= 0).all() and (idx < total).all()) - # Phase-2 replacement happened: at least one slot holds a row added - # after the reservoir filled (index >= cap). - self.assertTrue((idx >= cap).any(), "no Phase-2 replacement occurred") - def test_on_train_end_runs_faiss(self) -> None: """on_train_end triggers FAISS fit and clears buffer.""" try: @@ -194,14 +138,14 @@ def test_on_train_end_runs_faiss(self) -> None: # Accumulate enough samples (FAISS K-Means needs at least K points) for _ in range(8): model.predict(_make_batch(B, input_dim)) - self.assertGreater(model._n_seen, 0) + self.assertGreater(model._reservoir.n_seen, 0) # Trigger one-shot FAISS fit; a real fit must request a tail checkpoint self.assertTrue(model.on_train_end()) # Reservoir should be released after the fit - self.assertEqual(model._n_seen, 0) - self.assertIsNone(model._reservoir) + self.assertEqual(model._reservoir.n_seen, 0) + self.assertEqual(model._reservoir.n_filled, 0) # All layers should be initialized + centroids non-zero for layer in model._quantizer.layers: self.assertTrue(bool(layer._is_initialized.item())) @@ -226,7 +170,7 @@ def test_non_uniform_codebook_end_to_end(self) -> None: model = self._create_model(input_dim=input_dim, codebook=codebook) # Reservoir cap derives from the LARGEST K (16), not the first (8). self.assertEqual( - model._sample_cap, + model._reservoir.capacity, 16 * int(model._faiss_kwargs.get("max_points_per_centroid", 256)), ) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 02cfc63d6..629392fc4 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -18,9 +18,12 @@ :class:`ResidualKMeansQuantizer`. Centroids are injected by the FAISS backend via ``load_centroids_``; the only forward path is ``predict``. +* :class:`ReservoirSampler` — bounded uniform stream sample (Vitter + Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` + fills during training to feed the one-shot FAISS fit. """ -from typing import Tuple +from typing import Optional, Tuple import torch from torch import nn @@ -52,6 +55,93 @@ def recon_diagnostics( return mse, rel +class ReservoirSampler: + """Bounded uniform sample of a stream (Vitter Algorithm R). + + Keeps a uniform ``capacity``-row sample of all rows passed to ``add``, in + O(capacity) host (CPU) memory — used to subsample the training corpus for + the one-shot FAISS fit without buffering the whole corpus. The buffer is a + CPU float32 tensor, allocated lazily on the first ``add``. + + Args: + capacity (int): max rows retained. + dim (int): row width (feature dimension). + """ + + def __init__(self, capacity: int, dim: int) -> None: + self._cap = capacity + self._dim = dim + # Allocated lazily on the first add. _n_filled = used slots; + # _n_seen = running count for the accept prob. + self._buf: Optional[torch.Tensor] = None + self._n_filled = 0 + self._n_seen = 0 + + @property + def capacity(self) -> int: + """Max rows retained.""" + return self._cap + + @property + def n_seen(self) -> int: + """Total rows passed to ``add`` so far.""" + return self._n_seen + + @property + def n_filled(self) -> int: + """Rows currently held (<= capacity).""" + return self._n_filled + + @torch.no_grad() + def add(self, x: torch.Tensor) -> None: + """Stream a batch of rows into the reservoir. + + Args: + x (Tensor): rows to add, shape (B, dim). + """ + x = x.detach() + cap = self._cap + if self._buf is None: + self._buf = torch.empty(cap, self._dim, dtype=torch.float32) + + # Phase 1: fill empty slots first. x is already on the host (CPU-only + # model), so this is a dtype cast into the buffer, not a device copy. + if self._n_filled < cap: + take = min(x.shape[0], cap - self._n_filled) + self._buf[self._n_filled : self._n_filled + take] = x[:take].to( + torch.float32 + ) + self._n_filled += take + self._n_seen += take + x = x[take:] + if x.shape[0] == 0: + return + + # Phase 2: row j enters with prob cap/(n_seen+j+1), displacing a random + # slot. float64 keeps n_seen+j+1 exact past 2**24. + r = x.shape[0] + pos = self._n_seen + torch.arange(r) + accept = torch.rand(r) < (cap / (pos + 1).to(torch.float64)) + idx = accept.nonzero(as_tuple=True)[0] + if idx.numel() > 0: + slots = torch.randint(0, cap, (idx.numel(),)) + # Slot collisions are last-write-wins; O(B/cap) bias, negligible here. + self._buf[slots] = x[idx].to(torch.float32) + self._n_seen += r + + def sample(self) -> torch.Tensor: + """Return the filled portion of the reservoir, shape (n_filled, dim).""" + if self._buf is None or self._n_filled == 0: + return torch.empty(0, self._dim, dtype=torch.float32) + return self._buf[: self._n_filled] + + def reset(self) -> None: + """Drop the buffer and counters to free host memory.""" + self._buf = None + self._n_filled = 0 + self._n_seen = 0 + + class KMeansLayer(nn.Module): """Single layer of a residual K-Means stack. diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index 1b21604d3..d6b06a7f1 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -15,6 +15,7 @@ from tzrec.modules.sid.kmeans import ( KMeansLayer, + ReservoirSampler, recon_diagnostics, ) @@ -70,5 +71,77 @@ def test_post_fit_checkpoint_round_trips(self) -> None: torch.testing.assert_close(fresh.centroids, layer.centroids) +class ReservoirSamplerTest(unittest.TestCase): + """Tests for the bounded reservoir sampler (Vitter Algorithm R).""" + + def test_empty_sample(self) -> None: + """sample() before any add returns an empty (0, dim) tensor.""" + r = ReservoirSampler(capacity=10, dim=4) + self.assertEqual(r.sample().shape, (0, 4)) + self.assertEqual(r.n_seen, 0) + self.assertEqual(r.n_filled, 0) + + def test_caps_memory(self) -> None: + """The buffer is bounded at capacity regardless of stream length.""" + cap, dim, B = 10, 8, 16 + r = ReservoirSampler(capacity=cap, dim=dim) + for _ in range(20): # 320 rows >> cap + r.add(torch.randn(B, dim)) + self.assertEqual(r.n_seen, 20 * B) + self.assertEqual(r.n_filled, cap) + self.assertEqual(r.sample().shape, (cap, dim)) + + def test_phase2_replacement(self) -> None: + """Phase-2 replacement keeps a valid sample of real, in-range rows. + + Feeds identifiable rows (each row's value == its global stream index), + then asserts every slot still holds an intact fed row, all indices are + in range, and replacement past the initial fill actually happened — + exercising the accept-prob / slot-write logic that the count/shape-only + ``test_caps_memory`` cannot. + """ + torch.manual_seed(0) + dim, cap, B, n_batches = 4, 8, 4, 50 + r = ReservoirSampler(capacity=cap, dim=dim) + + gidx = 0 + for _ in range(n_batches): + rows = ( + torch.arange(gidx, gidx + B, dtype=torch.float32) + .unsqueeze(1) + .expand(B, dim) + .contiguous() + ) + gidx += B + r.add(rows) + + total = B * n_batches + self.assertEqual(r.n_seen, total) + self.assertEqual(r.n_filled, cap) + + res = r.sample() + idx = res[:, 0].round().long() + # Each stored row is an intact fed row (all columns equal its index). + self.assertTrue( + torch.equal(res, idx.unsqueeze(1).float().expand_as(res)), + "reservoir holds corrupted (non-fed) rows", + ) + # All indices are valid stream positions. + self.assertTrue((idx >= 0).all() and (idx < total).all()) + # Phase-2 replacement happened: at least one slot holds a row added + # after the reservoir filled (index >= cap). + self.assertTrue((idx >= cap).any(), "no Phase-2 replacement occurred") + + def test_reset(self) -> None: + """reset() drops the buffer and counters.""" + r = ReservoirSampler(capacity=10, dim=4) + r.add(torch.randn(5, 4)) + self.assertEqual(r.n_filled, 5) + r.reset() + self.assertEqual(r.n_seen, 0) + self.assertEqual(r.n_filled, 0) + self.assertEqual(r.sample().shape, (0, 4)) + + if __name__ == "__main__": unittest.main() From e296c8d32e629b021e6cd9f7947b9cc2c0e609f3 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:30:08 +0000 Subject: [PATCH 060/129] [refactor] ReservoirSampler: log capacity + dim on construction Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 629392fc4..7c89f7d17 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -28,6 +28,8 @@ import torch from torch import nn +from tzrec.utils.logging_util import logger + def recon_diagnostics( x: torch.Tensor, @@ -76,6 +78,7 @@ def __init__(self, capacity: int, dim: int) -> None: self._buf: Optional[torch.Tensor] = None self._n_filled = 0 self._n_seen = 0 + logger.info("[ReservoirSampler] capacity=%d, dim=%d", capacity, dim) @property def capacity(self) -> int: From 892a8d26538fd816b715170a17847f2bcc5d3c55 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:44:41 +0000 Subject: [PATCH 061/129] [fix] SID code-review: fail-fast cap, skip pre-fit eval, dedup MSE, drop x0 clone Addresses /code-review findings on SidRqkmeans: #1 Fail fast at __init__ when the reservoir cap < max(codebook) (an explicit train_sample_size too small would otherwise assert in train_offline only at on_train_end, after the whole training pass). #2 update_metric returns early when the codebook isn't fitted yet (ResidualKMeansQuantizer.is_fitted), so in-loop eval before the end-of-train FAISS fit no longer logs garbage mse/rel_loss/unique_sid_ratio over the all-zero codebook. #3 Stop computing MSE twice per eval batch: extract a relative_l1 helper (recon_diagnostics now reuses it) and call it directly instead of recon_diagnostics-then-discard-mse alongside MeanSquaredError.update. #4 Drop the persistent ~0.5GB x0 clone in train_offline for the common (normalize_residuals=False) path: reconstruct the per-layer log reference on the fly via the out + x == x0 invariant; keep the clone only when normalization breaks that invariant. Tests: add fail-fast and pre-fit-skip cases. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 25 +++++++++++++++---- tzrec/models/sid_rqkmeans_test.py | 16 ++++++++++++ tzrec/modules/sid/kmeans.py | 23 ++++++++++++++--- .../modules/sid/residual_kmeans_quantizer.py | 23 +++++++++++++---- 4 files changed, 74 insertions(+), 13 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 3f3b6954f..a79250bac 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -27,7 +27,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import ReservoirSampler, recon_diagnostics +from tzrec.modules.sid.kmeans import ReservoirSampler, relative_l1 from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -116,6 +116,14 @@ def __init__( # so no per-rank split. target = self._model_config.train_sample_size cap = target if target > 0 else self._quantizer.default_fit_sample_size() + # Fail fast: FAISS needs >= K points to fit each layer, so a cap below + # the largest codebook would only assert at on_train_end — after the + # whole training pass. (The default cap is always >= max(K).) + max_k = max(self._n_embed_list) + assert cap >= max_k, ( + f"reservoir cap ({cap}) < largest codebook size ({max_k}); set " + f"train_sample_size >= {max_k} (or 0 for the default)." + ) self._reservoir = ReservoirSampler(cap, self._input_dim) # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. @@ -207,13 +215,20 @@ def update_metric( batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ + # In-loop eval can run before the end-of-train FAISS fit; the codebook + # is all-zeros then, so codes/reconstruction are meaningless. Skip until + # fitted so those bogus values don't pollute the eval metrics. + if not self._quantizer.is_fitted: + return + if "quantized" in predictions: embedding = self._extract_feature(batch) - _, rel = recon_diagnostics(embedding, predictions["quantized"]) - # mse aggregates (preds, target) itself; rel_loss has no - # torchmetrics equivalent, so it stays a MeanMetric. + # mse aggregates (preds, target) itself; rel_loss has no torchmetrics + # equivalent, so compute it directly (only rel is needed here). self._metric_modules["mse"].update(predictions["quantized"], embedding) - self._metric_modules["rel_loss"].update(rel) + self._metric_modules["rel_loss"].update( + relative_l1(embedding, predictions["quantized"]) + ) self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index c312ebad3..e41fba295 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -105,6 +105,11 @@ def test_sample_cap_from_train_sample_size(self) -> None: model._reservoir.capacity, model._quantizer.default_fit_sample_size() ) + def test_init_raises_on_too_small_train_sample_size(self) -> None: + """train_sample_size below the largest codebook fails fast at init.""" + with self.assertRaisesRegex(AssertionError, "largest codebook"): + self._create_model(codebook=[16, 16], train_sample_size=8) + def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 @@ -270,6 +275,17 @@ def test_eval_metric_path(self) -> None: self.assertIn(key, metrics) self.assertTrue(torch.isfinite(torch.as_tensor(metrics[key])).all()) + def test_update_metric_skipped_before_fit(self) -> None: + """Pre-fit eval (unfitted codebook) does not pollute metric state.""" + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim) + model.init_metric() + model.eval() + # Codebook not fitted yet: predict emits zeros; update_metric must skip. + batch = _make_batch(B, input_dim) + model.update_metric(model.predict(batch), batch) + self.assertEqual(model._metric_modules["unique_sid_ratio"].count.item(), 0.0) + def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 7c89f7d17..50b2263f3 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -50,11 +50,28 @@ def recon_diagnostics( mse: scalar ``((out - x) ** 2).mean()``. rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. """ - mse = ((out - x) ** 2).mean() - rel = ( + return ((out - x) ** 2).mean(), relative_l1(x, out, epsilon) + + +def relative_l1( + x: torch.Tensor, + out: torch.Tensor, + epsilon: float = 1e-4, +) -> torch.Tensor: + """Relative-L1 error ``mean(|x - out| / (max(|x|, |out|) + eps))``. + + Symmetric relative error in [0, 1] (verbatim port of OpenOneRec's + ``calc_loss``). Used standalone by :meth:`SidRqkmeans.update_metric` (which + needs only ``rel``, not the MSE :meth:`recon_diagnostics` also computes). + + Args: + x: ground-truth embedding, shape (B, D). + out: quantized reconstruction, shape (B, D). + epsilon: numerical stabilizer for the denominator. + """ + return ( torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) ).mean() - return mse, rel class ReservoirSampler: diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 1bfe20267..8a0c8a176 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -127,6 +127,15 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: cluster_ids, quantized_sum, _ = self._residual_pass(input) return cluster_ids, quantized_sum + @property + def is_fitted(self) -> bool: + """Whether ``train_offline`` has populated every layer's codebook. + + ``forward`` is callable before the fit (uninitialized layers emit + zeros), so reconstruction outputs are meaningful only once this is True. + """ + return all(layer.is_initialized for layer in self.layers) + @torch.no_grad() def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """Get centroid weights for a specific layer. @@ -189,10 +198,12 @@ def train_offline( f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" ) out = torch.zeros_like(x) - # Original input, kept only for the log: the per-layer diagnostic is the - # cumulative recon error of x0 by the centroid sum (what update_metric - # reports). ``out + x`` would equal it only without normalization. - x0 = x.clone() if verbose else None + # The per-layer log reports the cumulative recon error of the original + # input x0 by the centroid sum. Without normalization the invariant + # ``out + x == x0`` holds, so x0 is reconstructed on the fly below and we + # skip the persistent (N, D) clone; with normalization x is rescaled each + # layer, breaking the invariant, so the clone is required. + x0 = x.clone() if (verbose and self.normalize_residuals) else None # CPU-only fit: SidRqkmeans refuses to initialize when CUDA is visible, # so the codebook is always built on CPU. Drop any stale ``gpu`` request @@ -233,10 +244,12 @@ def train_offline( del idx, q if verbose: + # x0 == out + x without normalization (see above). + ref = x0 if x0 is not None else out + x logger.info( "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, - self._calc_loss(x0, out), # cumulative recon of original input + self._calc_loss(ref, out), # cumulative recon of original input ) self.layers[layer_idx].load_centroids_(centroids) From b14304af2a6fae7c29cd17ddd46d5f0b9930c129 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:50:03 +0000 Subject: [PATCH 062/129] [simplify] SID: raise (not assert) for cap guard; name normalize_residuals - __init__: the cap < max(codebook) fail-fast used a bare assert, which is stripped under `python -O` (defeating the fail-fast purpose) and was inconsistent with the two sibling raise-guards in the same constructor. Convert to raise RuntimeError; update the test accordingly. - train_offline: `ref = x0 if self.normalize_residuals else out + x` names the actual reason instead of re-deriving it via the `x0 is not None` sentinel. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 9 +++++---- tzrec/models/sid_rqkmeans_test.py | 2 +- tzrec/modules/sid/residual_kmeans_quantizer.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index a79250bac..af9b5b7e7 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -120,10 +120,11 @@ def __init__( # the largest codebook would only assert at on_train_end — after the # whole training pass. (The default cap is always >= max(K).) max_k = max(self._n_embed_list) - assert cap >= max_k, ( - f"reservoir cap ({cap}) < largest codebook size ({max_k}); set " - f"train_sample_size >= {max_k} (or 0 for the default)." - ) + if cap < max_k: + raise RuntimeError( + f"reservoir cap ({cap}) < largest codebook size ({max_k}); set " + f"train_sample_size >= {max_k} (or 0 for the default)." + ) self._reservoir = ReservoirSampler(cap, self._input_dim) # KMeans has no learnable params; a dummy keeps the optimizer/DDP happy. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index e41fba295..782991eac 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -107,7 +107,7 @@ def test_sample_cap_from_train_sample_size(self) -> None: def test_init_raises_on_too_small_train_sample_size(self) -> None: """train_sample_size below the largest codebook fails fast at init.""" - with self.assertRaisesRegex(AssertionError, "largest codebook"): + with self.assertRaisesRegex(RuntimeError, "largest codebook"): self._create_model(codebook=[16, 16], train_sample_size=8) def test_predict_collects_buffer(self) -> None: diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 8a0c8a176..9816341e9 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -245,7 +245,7 @@ def train_offline( if verbose: # x0 == out + x without normalization (see above). - ref = x0 if x0 is not None else out + x + ref = x0 if self.normalize_residuals else out + x logger.info( "[ResidualKMeansQuantizer][offline_faiss][layer %d] %s", layer_idx, From eb39b5e294b3a71599f871b8d505777e5393addf Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 03:53:02 +0000 Subject: [PATCH 063/129] [style] SID: trim verbose comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tighten the multi-line block comments added across the recent SID work (CPU-only/single-process guards, reservoir cap, x0 invariant, gpu-kwarg drop, reservoir Phase-1, update_metric) — keep the load-bearing "why", drop the over-explanation the error messages and code already convey. Comments only. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 27 +++++++------------ tzrec/modules/sid/kmeans.py | 4 +-- .../modules/sid/residual_kmeans_quantizer.py | 14 +++++----- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index af9b5b7e7..8e181b1c8 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -72,19 +72,16 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - # CPU-only: everything (embeddings, reservoir, FAISS fit) stays on the - # host, so there are no device copies on the train path. Refuse to run - # when CUDA is visible rather than silently shuttling tensors to/from a - # GPU; launch with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host). + # CPU-only: embeddings, reservoir, and FAISS fit all stay on the host, + # so there are no device copies. Refuse to run when CUDA is visible. if torch.cuda.is_available(): raise RuntimeError( "SidRqkmeans is CPU-only, but a CUDA device is visible. " 'Run with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host).' ) - # Single-process only: the FAISS fit runs on one process over its local - # reservoir, with no cross-rank gather/broadcast. Fail fast here rather - # than after a full (wasted) training pass. + # Single-process only: the fit runs over one process's local reservoir, + # with no cross-rank gather. Fail fast before the (wasted) train pass. if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: raise RuntimeError( "SidRqkmeans supports single-process training only " @@ -109,16 +106,13 @@ def __init__( faiss_kmeans_kwargs=self._faiss_kwargs, ) - # Bounded host reservoir for the end-of-loop FAISS fit: cap at - # ``train_sample_size`` when set (>0), else the points the FAISS fit - # subsamples to (``default_fit_sample_size``) — rather than buffer the - # whole corpus. Single-process only (see the world_size guard above), - # so no per-rank split. + # Bounded host reservoir for the end-of-loop fit: cap at + # ``train_sample_size`` (when >0) else the fit's subsample size, rather + # than buffer the whole corpus. target = self._model_config.train_sample_size cap = target if target > 0 else self._quantizer.default_fit_sample_size() - # Fail fast: FAISS needs >= K points to fit each layer, so a cap below - # the largest codebook would only assert at on_train_end — after the - # whole training pass. (The default cap is always >= max(K).) + # Fail fast: a cap below the largest codebook would only fail deep in + # train_offline, after the whole training pass. max_k = max(self._n_embed_list) if cap < max_k: raise RuntimeError( @@ -224,8 +218,7 @@ def update_metric( if "quantized" in predictions: embedding = self._extract_feature(batch) - # mse aggregates (preds, target) itself; rel_loss has no torchmetrics - # equivalent, so compute it directly (only rel is needed here). + # rel_loss has no torchmetrics equivalent, so compute it directly. self._metric_modules["mse"].update(predictions["quantized"], embedding) self._metric_modules["rel_loss"].update( relative_l1(embedding, predictions["quantized"]) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 50b2263f3..11df2b65e 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -124,8 +124,8 @@ def add(self, x: torch.Tensor) -> None: if self._buf is None: self._buf = torch.empty(cap, self._dim, dtype=torch.float32) - # Phase 1: fill empty slots first. x is already on the host (CPU-only - # model), so this is a dtype cast into the buffer, not a device copy. + # Phase 1: fill empty slots first. x is on the host, so ``.to`` is a + # dtype cast into the buffer, not a device copy. if self._n_filled < cap: take = min(x.shape[0], cap - self._n_filled) self._buf[self._n_filled : self._n_filled + take] = x[:take].to( diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 9816341e9..29ad037d1 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -198,16 +198,14 @@ def train_offline( f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" ) out = torch.zeros_like(x) - # The per-layer log reports the cumulative recon error of the original - # input x0 by the centroid sum. Without normalization the invariant - # ``out + x == x0`` holds, so x0 is reconstructed on the fly below and we - # skip the persistent (N, D) clone; with normalization x is rescaled each - # layer, breaking the invariant, so the clone is required. + # x0 (original input) feeds the per-layer recon log. Without + # normalization ``out + x == x0``, so it's rebuilt on the fly below and + # the persistent (N, D) clone is skipped; normalization rescales x and + # breaks that invariant, so clone then. x0 = x.clone() if (verbose and self.normalize_residuals) else None - # CPU-only fit: SidRqkmeans refuses to initialize when CUDA is visible, - # so the codebook is always built on CPU. Drop any stale ``gpu`` request - # from the config so a faiss-gpu build can't try to use an absent GPU. + # CPU-only fit (SidRqkmeans refuses CUDA). Drop any stale ``gpu`` kwarg + # so a faiss-gpu build can't target an absent GPU. kwargs = dict(self.faiss_kmeans_kwargs) kwargs.pop("gpu", None) if verbose: From 8bf50aa857674606ed89cff19e184e16451f0ffb Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 04:41:40 +0000 Subject: [PATCH 064/129] [refactor] SID: move init_metric/update_metric to BaseSidModel + RelativeL1 metric MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses maintainer review #3. - New custom torchmetrics metric RelativeL1 (tzrec/metrics/relative_l1.py): symmetric relative-L1 |t-p|/(max(|t|,|p|)+eps), count-weighted aggregation. A proper Metric class (like UniqueRatio), NOT torchmetrics MeanAbsolutePercentageError — MAPE's asymmetric |t-p|/|t| denominator differs from OpenOneRec's calc_loss, which this is a verbatim port of. - BaseSidModel now owns init_metric (mse + rel_loss + unique_sid_ratio) and a generic update_metric that re-extracts the target embedding and gates all eval metrics on a non-None _reconstruction() hook (so a not-yet-fitted model logs nothing). - SidRqkmeans drops its init_metric/update_metric overrides and implements _reconstruction() -> quantized (or None until the FAISS fit), inheriting the shared metric logic. Drop now-unused torchmetrics / relative_l1 imports. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/relative_l1.py | 50 ++++++++++++++++++++++++ tzrec/metrics/relative_l1_test.py | 49 +++++++++++++++++++++++ tzrec/models/sid_model.py | 65 +++++++++++++++++++++++++++---- tzrec/models/sid_rqkmeans.py | 57 +++++++-------------------- 4 files changed, 169 insertions(+), 52 deletions(-) create mode 100644 tzrec/metrics/relative_l1.py create mode 100644 tzrec/metrics/relative_l1_test.py diff --git a/tzrec/metrics/relative_l1.py b/tzrec/metrics/relative_l1.py new file mode 100644 index 000000000..72a55c28d --- /dev/null +++ b/tzrec/metrics/relative_l1.py @@ -0,0 +1,50 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torchmetrics import Metric + + +class RelativeL1(Metric): + """Mean symmetric relative-L1 error ``|t - p| / (max(|t|, |p|) + eps)``. + + A bounded reconstruction-error metric (0 = exact, → 1 = unrelated). It is a + verbatim port of OpenOneRec's residual-K-Means ``calc_loss`` and is + deliberately **not** ``torchmetrics.MeanAbsolutePercentageError``, which uses + the asymmetric ``|t - p| / |t|`` denominator. Aggregation is element-wise + (count-weighted), so the reported value is the mean over all elements seen. + """ + + higher_is_better = False + is_differentiable = True + + def __init__(self, epsilon: float = 1e-4, **kwargs) -> None: + super().__init__(**kwargs) + self.epsilon = epsilon + self.add_state("sum_rel", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Accumulate the relative-L1 error between ``preds`` and ``target``. + + Args: + preds (Tensor): reconstruction, shape (B, D). + target (Tensor): ground-truth embedding, shape (B, D). + """ + rel = torch.abs(target - preds) / ( + torch.maximum(torch.abs(target), torch.abs(preds)) + self.epsilon + ) + self.sum_rel += rel.sum() + self.count += rel.numel() + + def compute(self) -> torch.Tensor: + """Mean relative-L1 over all elements (NaN before any update).""" + return self.sum_rel / self.count diff --git a/tzrec/metrics/relative_l1_test.py b/tzrec/metrics/relative_l1_test.py new file mode 100644 index 000000000..0f89c2ccd --- /dev/null +++ b/tzrec/metrics/relative_l1_test.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.metrics.relative_l1 import RelativeL1 + + +class RelativeL1Test(unittest.TestCase): + def test_zero_on_identity(self) -> None: + metric = RelativeL1() + x = torch.randn(8, 4) + metric.update(x, x.clone()) + self.assertAlmostEqual(metric.compute().item(), 0.0, places=6) + + def test_matches_formula(self) -> None: + metric = RelativeL1(epsilon=1e-4) + p = torch.tensor([[1.0, 0.0]]) + t = torch.tensor([[0.0, 2.0]]) + # |t-p|/(max(|t|,|p|)+eps): [1/(1+eps), 2/(2+eps)], mean of the two. + expected = (1.0 / (1.0 + 1e-4) + 2.0 / (2.0 + 1e-4)) / 2 + metric.update(p, t) + self.assertAlmostEqual(metric.compute().item(), expected, places=5) + + def test_count_weighted_across_updates(self) -> None: + """Aggregation is element-wise, not a mean of per-batch means.""" + metric = RelativeL1() + metric.update(torch.zeros(1, 4), torch.ones(1, 4)) # 4 elems, rel ~1 + metric.update(torch.ones(3, 4), torch.ones(3, 4)) # 12 elems, rel 0 + # Element-weighted: 4 nonzero over 16 elems -> ~0.25, NOT (1+0)/2 = 0.5. + per = 1.0 / (1.0 + 1e-4) # rel of a 0-vs-1 element (with epsilon) + self.assertAlmostEqual(metric.compute().item(), 4 * per / 16, places=6) + + def test_nan_before_update(self) -> None: + self.assertTrue(torch.isnan(RelativeL1().compute())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 973fcf99f..51fd9a179 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -18,6 +18,7 @@ from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature +from tzrec.metrics.relative_l1 import RelativeL1 from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel from tzrec.protos.model_pb2 import ModelConfig @@ -39,10 +40,10 @@ class BaseSidModel(BaseModel): proxy). Subclasses build their quantizer in ``__init__`` (after calling - ``super().__init__``) and implement :meth:`predict` and :meth:`loss`. - They extend :meth:`init_metric` (via ``super()``) and implement - :meth:`update_metric` to populate the registered metrics - (:meth:`update_train_metric` defaults to a no-op). + ``super().__init__``) and implement :meth:`predict`, :meth:`loss`, and + :meth:`_reconstruction` (which exposes the model's reconstruction of the + input embedding for the shared :meth:`update_metric`). + (:meth:`update_train_metric` defaults to a no-op.) Args: model_config (ModelConfig): an instance of ModelConfig. @@ -99,14 +100,62 @@ def init_loss(self) -> None: def init_metric(self) -> None: """Initialize the eval metrics shared by all SID models. - ``mse``: reconstruction error (input vs. quantized / decoded). - ``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows / - batch size; a batch-size-sensitive diversity proxy, not global - coverage). Subclasses call ``super().init_metric()`` then add extras. + - ``mse``: reconstruction error (input vs. quantized / decoded). + - ``rel_loss``: symmetric relative-L1 reconstruction error + (:class:`~tzrec.metrics.relative_l1.RelativeL1`); meaningful only with + ``normalize_residuals=False`` (else the reconstruction and the input + live on different scales). + - ``unique_sid_ratio``: mean per-batch unique-SID ratio (distinct rows / + batch size; a batch-size-sensitive diversity proxy, not global + coverage). + + Subclasses that add extras call ``super().init_metric()`` first. """ self._metric_modules["mse"] = torchmetrics.MeanSquaredError() + self._metric_modules["rel_loss"] = RelativeL1() self._metric_modules["unique_sid_ratio"] = UniqueRatio() + def _reconstruction( + self, predictions: Dict[str, torch.Tensor] + ) -> Optional[torch.Tensor]: + """The model's reconstruction of the input embedding, or None. + + Returns the (B, D) tensor that ``mse``/``rel_loss`` compare against the + input embedding — e.g. ``predictions["quantized"]`` (RQ-KMeans) or + ``predictions["x_hat"]`` (RQ-VAE). Returns None when it is unavailable or + not yet meaningful this step (e.g. before a K-Means fit), in which case + :meth:`update_metric` skips the eval metrics entirely. + + Args: + predictions (dict): a dict of predicted result. + """ + raise NotImplementedError + + def update_metric( + self, + predictions: Dict[str, torch.Tensor], + batch: Batch, + losses: Optional[Dict[str, torch.Tensor]] = None, + ) -> None: + """Update eval metrics from a reconstruction + the re-extracted input. + + The target embedding is re-extracted from ``batch`` (it is an input, not + a model output). All three metrics are gated on a non-None + :meth:`_reconstruction` so a not-yet-fitted model does not log garbage. + + Args: + predictions (dict): a dict of predicted result. + batch (Batch): input batch data. + losses (dict, optional): a dict of loss. + """ + recon = self._reconstruction(predictions) + if recon is None: + return + embedding = self._extract_feature(batch) + self._metric_modules["mse"].update(recon, embedding) + self._metric_modules["rel_loss"].update(recon, embedding) + self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) + def update_train_metric( self, predictions: Dict[str, torch.Tensor], diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 8e181b1c8..d8fd2d677 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -21,13 +21,12 @@ import torch import torch.distributed as dist -import torchmetrics from torch import nn from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import ReservoirSampler, relative_l1 +from tzrec.modules.sid.kmeans import ReservoirSampler from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -177,54 +176,24 @@ def loss( """ return {"dummy_loss": self._dummy_param.sum() * 0.0} - def init_metric(self) -> None: - """Register eval metrics (shared ``mse`` + ``rel_loss``). + def _reconstruction( + self, predictions: Dict[str, torch.Tensor] + ) -> Optional[torch.Tensor]: + """Centroid-sum reconstruction, or None until the codebook is fit. - Train-time metrics are intentionally absent: ``predict`` returns dummy - codes pre-fit, so the inherited no-op ``update_train_metric`` keeps the - train path empty. - """ - super().init_metric() - self._metric_modules["rel_loss"] = torchmetrics.MeanMetric() - - def update_metric( - self, - predictions: Dict[str, torch.Tensor], - batch: Batch, - losses: Optional[Dict[str, torch.Tensor]] = None, - ) -> None: - """Update metric state. - - The reconstruction target (the input embedding) is re-extracted from - ``batch`` — it is an input, not a model output. ``quantized`` is present - only in eval (see ``predict``), so this runs eval-only. - - Note: ``mse``/``rel_loss`` compare that embedding against the centroid-sum - reconstruction. They are meaningful reconstruction metrics only with - ``normalize_residuals=False`` (the default); with normalization the - centroids live on the rescaled-residual scale, so the two quantities - don't share a scale (same caveat the train_offline per-layer log carries). + ``quantized`` is present only in eval and is all-zeros before the + end-of-train FAISS fit, so gate on the fit — the shared + :meth:`BaseSidModel.update_metric` then skips the eval metrics until the + reconstruction is meaningful. (Meaningful only with + ``normalize_residuals=False``; with normalization the centroids live on + the rescaled-residual scale, so the two quantities don't share a scale.) Args: predictions (dict): a dict of predicted result. - batch (Batch): input batch data. - losses (dict, optional): a dict of loss. """ - # In-loop eval can run before the end-of-train FAISS fit; the codebook - # is all-zeros then, so codes/reconstruction are meaningless. Skip until - # fitted so those bogus values don't pollute the eval metrics. if not self._quantizer.is_fitted: - return - - if "quantized" in predictions: - embedding = self._extract_feature(batch) - # rel_loss has no torchmetrics equivalent, so compute it directly. - self._metric_modules["mse"].update(predictions["quantized"], embedding) - self._metric_modules["rel_loss"].update( - relative_l1(embedding, predictions["quantized"]) - ) - - self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) + return None + return predictions.get("quantized") @torch.no_grad() def on_train_end(self) -> bool: From e8a3609dcac60ad07e79c242812dabed4ca0121e Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 04:54:39 +0000 Subject: [PATCH 065/129] [test] SID: add sid_integration_test (train -> fit -> checkpoint -> eval) Addresses maintainer review #2 (integration test in tzrec/tests/, like the match/rank integration tests). Drives a real train_eval -> eval over a tiny prepared embedding parquet and asserts on_train_end forced a final checkpoint (the codebook exists only after the fit) and a post-fit eval_result was written. Because SidRqkmeans is CPU-only + single-process, the test forces CUDA_VISIBLE_DEVICES="" and TEST_NPROC_PER_NODE=1 (the harness otherwise defaults to GPU + nproc=2). Verified passing on the DSW remote (torchrec 1.6); the local container can't run train_eval (torchrec 1.5). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/tests/configs/sid_rqkmeans_mock.config | 54 ++++++++++ tzrec/tests/sid_integration_test.py | 105 +++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 tzrec/tests/configs/sid_rqkmeans_mock.config create mode 100644 tzrec/tests/sid_integration_test.py diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config new file mode 100644 index 000000000..d473dd705 --- /dev/null +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -0,0 +1,54 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/sid_rqkmeans_mock" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 256 + dataset_type: ParquetDataset + fg_mode: FG_DAG + num_workers: 2 +} +feature_configs { + raw_feature { + feature_name: "item_emb" + expression: "item:embedding" + value_dim: 16 + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "item_emb" + group_type: DEEP + } + sid_rqkmeans { + input_dim: 16 + codebook: 16 + codebook: 16 + codebook: 16 + normalize_residuals: false + embedding_feature_name: "item_emb" + faiss_kmeans_kwargs { + fields { key: "niter" value { number_value: 5 } } + fields { key: "seed" value { number_value: 42 } } + } + } +} diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py new file mode 100644 index 000000000..711e69ec0 --- /dev/null +++ b/tzrec/tests/sid_integration_test.py @@ -0,0 +1,105 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import shutil +import tempfile +import unittest +from unittest import mock + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + +from tzrec.tests import utils +from tzrec.utils import config_util + + +class SidIntegrationTest(unittest.TestCase): + def setUp(self): + self.success = False + if not os.path.exists("./tmp"): + os.makedirs("./tmp") + self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") + os.chmod(self.test_dir, 0o755) + # SID models are CPU-only (refuse a visible CUDA device) and + # single-process (refuse world_size > 1), so hide CUDA and pin + # nproc=1 — the GPU CI harness otherwise defaults to GPU + nproc=2. + patcher = mock.patch.dict( + os.environ, {"CUDA_VISIBLE_DEVICES": "", "TEST_NPROC_PER_NODE": "1"} + ) + patcher.start() + self.addCleanup(patcher.stop) + + def tearDown(self): + if self.success and os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def _prepare_config(self, num_rows: int, dim: int) -> str: + """Write an embedding parquet + a SID config pointed at it. + + Single dense ``embedding`` column, no labels — SID reads the item + embedding straight from the batch. Returns the saved config path. + """ + data_dir = os.path.join(self.test_dir, "sid_data") + os.makedirs(data_dir, exist_ok=True) + emb = np.random.rand(num_rows, dim).astype(np.float32) + pq.write_table( + pa.table({"embedding": pa.array(list(emb))}), + os.path.join(data_dir, "part-0.parquet"), + ) + data_glob = os.path.join(data_dir, "*.parquet") + + # train_input_path set -> load_config_for_test uses it as-is (the + # FG_DAG auto-mock path is match-model-specific; SID is single-table). + config = config_util.load_pipeline_config( + "tzrec/tests/configs/sid_rqkmeans_mock.config" + ) + config.train_input_path = data_glob + config.eval_input_path = data_glob + config_path = os.path.join(self.test_dir, "sid.config") + config_util.save_message(config, config_path) + return config_path + + def test_sid_rqkmeans_train_eval(self): + """End-to-end train -> on_train_end FAISS fit -> checkpoint -> eval. + + Locks down the load-bearing path: the codebook exists only after + ``on_train_end``, which forces the final checkpoint; the post-fit eval + then reports finite reconstruction metrics. + """ + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + config_path = self._prepare_config(num_rows=2048, dim=16) + + self.success = utils.test_train_eval(config_path, self.test_dir) + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + # on_train_end fitted the codebook and forced a final checkpoint. + self.assertTrue( + glob.glob(os.path.join(self.test_dir, "train", "model.ckpt-*")), + "no checkpoint persisted after on_train_end", + ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "train", "eval_result.txt")), + "no eval_result.txt produced", + ) + + +if __name__ == "__main__": + unittest.main() From 3dfbde07693018ff0732e640045bbe5ba653c919 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 05:17:49 +0000 Subject: [PATCH 066/129] [test] checkpoint: verify force re-save overwrites the same step MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test_force_overwrite_same_step: save step 5 (centroids=0), then re-save the SAME step with different params (centroids=7) — assert a non-force re-save dedupes (no overwrite) while a force re-save overwrites, and the reloaded checkpoint holds the later params. This is the on_train_end post-fit path: a periodic save at the final step, then a forced re-save of the fitted codebook at the same step. Verified on the DSW remote (torchrec 1.6); the local container can't import checkpoint_util (torchrec 1.5). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util_test.py | 59 +++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/tzrec/utils/checkpoint_util_test.py b/tzrec/utils/checkpoint_util_test.py index 8fc6130f4..6f3b38757 100644 --- a/tzrec/utils/checkpoint_util_test.py +++ b/tzrec/utils/checkpoint_util_test.py @@ -171,6 +171,52 @@ def _remap_restore_worker(test_dir, rank, world_size, port, remap_file_path): shard_w_2_m2.gather(0) +def _force_overwrite_worker(test_dir, rank, world_size, port): + """force=True re-save at an already-saved step must overwrite it. + + Saves step 5 with centroids=0, then re-saves the SAME step with different + params (centroids=7): a non-force re-save dedupes (no overwrite), a force + re-save overwrites. Reloads and asserts the persisted step-5 checkpoint + holds the later params. (This is the on_train_end post-fit checkpoint path: + a periodic save at the final step, then a forced re-save of the fitted + codebook at the same step.) + """ + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + dist.init_process_group(backend="gloo") + + class BufModel(nn.Module): + def __init__(self, fill): + super().__init__() + self.register_buffer("centroids", torch.full((4, 3), float(fill))) + + manager = checkpoint_util.CheckpointManager(test_dir, keep_checkpoint_max=0) + + # Initial save at step 5 (pre-fit: centroids = 0). + assert manager.maybe_save(5, BufModel(0.0), final=True), "initial save" + + # Same step, different params: non-force dedupes; force overwrites. + model = BufModel(7.0) + assert not manager.maybe_save(5, model, final=True, force=False), ( + "non-force same-step save must dedupe (not overwrite)" + ) + assert manager.maybe_save(5, model, final=True, force=True), ( + "force same-step save must fire" + ) + manager.close() # drain the async prune worker + + # Reload: the persisted step-5 checkpoint must hold the LATER params (7), + # i.e. the force-save overwrote the earlier (0) one. + restored = BufModel(0.0) + checkpoint_util.restore_model(os.path.join(test_dir, "model.ckpt-5"), restored) + assert torch.allclose(restored.centroids, torch.full((4, 3), 7.0)), ( + f"overwrite failed: centroids={restored.centroids.flatten().tolist()}" + ) + dist.destroy_process_group() + + class CheckpointUtilTest(unittest.TestCase): def setUp(self): if not os.path.exists("./tmp"): @@ -327,6 +373,19 @@ def test_checkpoint_manager_discovery(self): ) self.assertEqual(manager.best_checkpoint()[1], 10) + def test_force_overwrite_same_step(self): + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + p = ctx.Process( + target=_force_overwrite_worker, args=(self.test_dir, 0, 1, port) + ) + p.start() + p.join(timeout=120) + if p.is_alive(): + p.terminate() + raise RuntimeError("force-overwrite worker timed out.") + self.assertEqual(p.exitcode, 0, "force-overwrite worker failed") + def test_dist_save_restore_model(self): port = misc_util.get_free_port() procs = [] From d67ccd1f01570c00a0390a138ab0fc1888a6abc8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:22:22 +0000 Subject: [PATCH 067/129] [review] split quantizer tests by module; clarify copy=True MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit #A: residual_quantizer_test.py tested both the base ResidualQuantizer and ResidualKMeansQuantizer. Split the K-Means tests into the matching residual_kmeans_quantizer_test.py (so each module has its own test file); the base tests stay put. #B: expand the train_offline copy=True comment — the residual loop mutates x in place and the input is a view into the reservoir buffer, so it must own a fresh copy (copy=True is a single guaranteed copy vs a double-copy clone). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../modules/sid/residual_kmeans_quantizer.py | 5 +- .../sid/residual_kmeans_quantizer_test.py | 114 ++++++++++++++++++ tzrec/modules/sid/residual_quantizer_test.py | 92 -------------- 3 files changed, 118 insertions(+), 93 deletions(-) create mode 100644 tzrec/modules/sid/residual_kmeans_quantizer_test.py diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 29ad037d1..0074331da 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -187,7 +187,10 @@ def train_offline( assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # Own one contiguous float32 copy to update in place as the residual. + # The loop below mutates x in place (the residual ``x -= q``), and the + # input is a view into the caller's float32 reservoir buffer — so own a + # fresh copy (copy=True forces one even when the dtype already matches, + # avoiding the double copy a separate ``.clone()`` would add). x = inputs.detach().to(dtype=torch.float32, copy=True).contiguous() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not diff --git a/tzrec/modules/sid/residual_kmeans_quantizer_test.py b/tzrec/modules/sid/residual_kmeans_quantizer_test.py new file mode 100644 index 000000000..42647468e --- /dev/null +++ b/tzrec/modules/sid/residual_kmeans_quantizer_test.py @@ -0,0 +1,114 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid.residual_kmeans_quantizer import ( + ResidualKMeansQuantizer, +) +from tzrec.modules.sid.residual_quantizer import ( + ResidualQuantizer, +) + + +class ResidualKMeansQuantizerTest(unittest.TestCase): + def test_is_subclass(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertIsInstance(rkq, ResidualQuantizer) + + def test_non_uniform_codebook_supported(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) + self.assertEqual(rkq.n_embed_list, [8, 4, 16]) + self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) + + def test_forward_returns_zeros_before_fit(self) -> None: + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) + codes, quantized = rkq(torch.randn(5, 4)) + self.assertEqual(codes.shape, (5, 2)) + self.assertEqual(quantized.shape, (5, 4)) + + def test_forward_is_fx_traceable(self) -> None: + """Predict forward must FX-trace. + + torchrec's inference pipeline symbolically traces the model, so the + per-batch distance path must be free of data-dependent control flow. + """ + import torch.fx as fx + + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) + for layer in rkq.layers: # populate centroids -> is_initialized=True + layer.load_centroids_(torch.randn(8, 4)) + traced = fx.symbolic_trace(rkq) + x = torch.randn(5, 4) + c_eager, q_eager = rkq(x) + c_traced, q_traced = traced(x) + torch.testing.assert_close(c_traced, c_eager) + torch.testing.assert_close(q_traced, q_eager) + + def test_train_offline_non_uniform(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + n_embed = [8, 4, 16] + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(512, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + # Each layer fit its own K centroids; codes stay in per-layer range. + codes, _ = rkq(torch.randn(7, 4)) + self.assertEqual(codes.shape, (7, 3)) + for i, k in enumerate(n_embed): + self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) + + def test_train_offline_then_decode(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) + + codes, _ = rkq(torch.randn(5, 4)) + self.assertTrue((codes >= 0).all() and (codes < 8).all()) + recon = rkq.decode_codes(codes) # inherited from the base + self.assertEqual(recon.shape, (5, 4)) + + def test_forward_get_codes_consistent(self) -> None: + """Forward ids and get_codes both route through the shared walk.""" + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rkq = ResidualKMeansQuantizer( + embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} + ) + rkq.train_offline(torch.randn(256, 4), verbose=False) + x = torch.randn(9, 4) + fwd_ids, fwd_quant = rkq(x) + torch.testing.assert_close(rkq.get_codes(x), fwd_ids) + # forward's residual-sum equals the centroid-sum reconstruction. + torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py index d23ef1cf5..c94cc545d 100644 --- a/tzrec/modules/sid/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -14,9 +14,6 @@ import torch from torch import nn -from tzrec.modules.sid.residual_kmeans_quantizer import ( - ResidualKMeansQuantizer, -) from tzrec.modules.sid.residual_quantizer import ( ResidualQuantizer, normalize_n_embed, @@ -145,94 +142,5 @@ def test_decode_codes_sum_and_dtype(self) -> None: self.assertEqual(recon16.dtype, torch.bfloat16) -class ResidualKMeansQuantizerTest(unittest.TestCase): - def test_is_subclass(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - self.assertIsInstance(rkq, ResidualQuantizer) - - def test_non_uniform_codebook_supported(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) - self.assertEqual(rkq.n_embed_list, [8, 4, 16]) - self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) - - def test_forward_returns_zeros_before_fit(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) - codes, quantized = rkq(torch.randn(5, 4)) - self.assertEqual(codes.shape, (5, 2)) - self.assertEqual(quantized.shape, (5, 4)) - - def test_forward_is_fx_traceable(self) -> None: - """Predict forward must FX-trace. - - torchrec's inference pipeline symbolically traces the model, so the - per-batch distance path must be free of data-dependent control flow. - """ - import torch.fx as fx - - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - for layer in rkq.layers: # populate centroids -> is_initialized=True - layer.load_centroids_(torch.randn(8, 4)) - traced = fx.symbolic_trace(rkq) - x = torch.randn(5, 4) - c_eager, q_eager = rkq(x) - c_traced, q_traced = traced(x) - torch.testing.assert_close(c_traced, c_eager) - torch.testing.assert_close(q_traced, q_eager) - - def test_train_offline_non_uniform(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - n_embed = [8, 4, 16] - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(512, 4), verbose=False) - self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) - # Each layer fit its own K centroids; codes stay in per-layer range. - codes, _ = rkq(torch.randn(7, 4)) - self.assertEqual(codes.shape, (7, 3)) - for i, k in enumerate(n_embed): - self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) - - def test_train_offline_then_decode(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(256, 4), verbose=False) - self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) - - codes, _ = rkq(torch.randn(5, 4)) - self.assertTrue((codes >= 0).all() and (codes < 8).all()) - recon = rkq.decode_codes(codes) # inherited from the base - self.assertEqual(recon.shape, (5, 4)) - - def test_forward_get_codes_consistent(self) -> None: - """Forward ids and get_codes both route through the shared walk.""" - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(256, 4), verbose=False) - x = torch.randn(9, 4) - fwd_ids, fwd_quant = rkq(x) - torch.testing.assert_close(rkq.get_codes(x), fwd_ids) - # forward's residual-sum equals the centroid-sum reconstruction. - torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) - - if __name__ == "__main__": unittest.main() From 6a736c582fda1a212b26d5410d39bf18469d58d2 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:27:25 +0000 Subject: [PATCH 068/129] [refactor] drop CheckpointManager force param; SID uses no periodic ckpts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review: remove the `force` knob from maybe_save instead of threading it through main.py. SID models run with save_checkpoints_steps and save_checkpoints_epochs = 0, so no periodic save lands on the final step and the tail final=True save is never deduped away — `force` isn't needed. - checkpoint_util.maybe_save: drop `force`; dedupe is `step == _last_ckpt_step`. - main.py: call on_train_end() for its side effect (the fit); tail save is maybe_save(..., final=True). - BaseModel/SidRqkmeans.on_train_end now return None (the bool existed only to feed `force`). - Remove the now-obsolete checkpoint force-overwrite test; update on_train_end return assertions. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/main.py | 11 ++---- tzrec/models/model.py | 12 ++---- tzrec/models/sid_rqkmeans.py | 16 +++----- tzrec/models/sid_rqkmeans_test.py | 10 ++--- tzrec/utils/checkpoint_util.py | 14 ++----- tzrec/utils/checkpoint_util_test.py | 59 ----------------------------- 6 files changed, 22 insertions(+), 100 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index fe4dd1079..e0b43b329 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -516,9 +516,10 @@ def run_eval(step: int, epoch: int) -> None: lr.step() # One-shot end-of-loop hook (default no-op; e.g. SidRqkmeans fits its FAISS - # codebook here). Returns True if it mutated persistable state, forcing the - # tail save below even when the last in-loop checkpoint hit the final step. - is_ckpt_after_train = _model.on_train_end() + # codebook here). SID models run with periodic checkpointing disabled + # (save_checkpoints_steps/epochs = 0), so the tail final=True save below is + # the only checkpoint and persists whatever on_train_end produced. + _model.on_train_end() _log_train( i_step, @@ -533,9 +534,6 @@ def run_eval(step: int, epoch: int) -> None: summary_writer.close() if train_config.is_profiling: prof.stop() - # ``force`` re-fires the save past maybe_save's per-step dedupe when - # on_train_end mutated persistable state (e.g. SidRqkmeans fit its codebook) - # after the last in-loop save landed on the final step. if ckpt_manager.maybe_save( i_step, model, @@ -543,7 +541,6 @@ def run_eval(step: int, epoch: int) -> None: dataloader_state, data_timestamp=data_timestamp, final=True, - force=is_ckpt_after_train, ): run_eval(i_step, i_epoch) ckpt_manager.close() diff --git a/tzrec/models/model.py b/tzrec/models/model.py index c6b2b952c..26ec63dbc 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -150,18 +150,14 @@ def compute_train_metric(self) -> Dict[str, torch.Tensor]: metric_results[metric_name] = metric.compute() return metric_results - def on_train_end(self) -> bool: + def on_train_end(self) -> None: """Hook fired once after the train_eval loop exits. Default no-op; override for one-shot end-of-loop work (e.g. - :class:`SidRqkmeans` fits its FAISS codebook here). - - Returns: - is_ckpt_after_train (bool): whether the hook mutated state that must - be persisted, so the loop forces a final checkpoint even if one was - already saved at the last step. Default ``False``. + :class:`SidRqkmeans` fits its FAISS codebook here). The tail + ``final=True`` checkpoint persists whatever it produced. """ - return False + return def sparse_parameters( self, diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index d8fd2d677..b2188e8c3 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -196,21 +196,18 @@ def _reconstruction( return predictions.get("quantized") @torch.no_grad() - def on_train_end(self) -> bool: + def on_train_end(self) -> None: """Fit the FAISS codebook once, after the train_eval loop exits. Overrides :meth:`BaseModel.on_train_end` (called unconditionally by ``tzrec.main``). Single-process only (enforced by the world_size guard in ``__init__``): the fit runs on one process over its local reservoir, - with no cross-rank gather/broadcast. + with no cross-rank gather/broadcast. The tail ``final=True`` checkpoint + then persists the fitted codebook (SID runs with periodic checkpointing + disabled, so that save is never deduped away). An empty reservoir only happens for a pathologically tiny corpus; the - fit is then skipped and ``False`` returned. - - Returns: - is_ckpt_after_train (bool): ``True`` if the codebook was fitted - (centroids changed → force a final checkpoint), ``False`` if the - fit was skipped (empty reservoir). + fit is then skipped. """ local = self._reservoir.sample() self._reservoir.reset() @@ -220,11 +217,10 @@ def on_train_end(self) -> bool: "[SidRqkmeans.on_train_end] empty reservoir; skipping FAISS " "fit. Did the train_eval loop run?" ) - return False + return logger.info( "[SidRqkmeans.on_train_end] fitting FAISS on %d samples (D=%d)." % (local.shape[0], local.shape[1]) ) self._quantizer.train_offline(local, verbose=True) - return True diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 782991eac..c41cb1cf1 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -145,8 +145,8 @@ def test_on_train_end_runs_faiss(self) -> None: model.predict(_make_batch(B, input_dim)) self.assertGreater(model._reservoir.n_seen, 0) - # Trigger one-shot FAISS fit; a real fit must request a tail checkpoint - self.assertTrue(model.on_train_end()) + # Trigger one-shot FAISS fit. + model.on_train_end() # Reservoir should be released after the fit self.assertEqual(model._reservoir.n_seen, 0) @@ -213,7 +213,7 @@ def test_normalize_residuals_end_to_end(self) -> None: model.train() for _ in range(8): model.predict(_make_batch(B, input_dim)) - self.assertTrue(model.on_train_end()) + model.on_train_end() for layer in model._quantizer.layers: self.assertTrue(layer.is_initialized) @@ -289,8 +289,8 @@ def test_update_metric_skipped_before_fit(self) -> None: def test_on_train_end_noop_on_empty_buffer(self) -> None: """on_train_end on an empty buffer is a warned no-op.""" model = self._create_model() - # No fit happened, so no tail checkpoint is requested. - self.assertFalse(model.on_train_end()) # should not raise + model.on_train_end() # warns and returns without fitting; must not raise + self.assertFalse(model._quantizer.is_fitted) def test_init_raises_under_ddp(self) -> None: """SidRqkmeans is single-process only: world_size>1 fails fast in init.""" diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index 612cf023d..c601fd432 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -399,7 +399,6 @@ def maybe_save( epoch: Optional[int] = None, data_timestamp: float = -1.0, final: bool = False, - force: bool = False, ) -> bool: """Save a checkpoint if a configured trigger fires; return whether it did. @@ -418,15 +417,8 @@ def maybe_save( epoch: current epoch; enables the epoch trigger when not None. data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. - final: request a save unconditionally (still subject to the dedupe), - e.g. at train end. This sets ``want``; it does not bypass the - per-step dedupe — that is what ``force`` is for. - force: bypass the per-step dedupe so a wanted save fires even if this - step was already saved — e.g. when end-of-train work mutated the - model state at the already-saved final step (see ``on_train_end``). - Orthogonal to ``final``: ``force`` only relaxes the dedupe and has - no effect on its own (it still needs ``want``, which ``final`` or a - cadence trigger supplies). + final: request a save unconditionally (still subject to the per-step + dedupe), e.g. at train end. Returns: True if a checkpoint was saved. @@ -452,7 +444,7 @@ def maybe_save( ): want = True - if not want or (step == self._last_ckpt_step and not force): + if not want or step == self._last_ckpt_step: return False self._last_ckpt_step = step diff --git a/tzrec/utils/checkpoint_util_test.py b/tzrec/utils/checkpoint_util_test.py index 6f3b38757..8fc6130f4 100644 --- a/tzrec/utils/checkpoint_util_test.py +++ b/tzrec/utils/checkpoint_util_test.py @@ -171,52 +171,6 @@ def _remap_restore_worker(test_dir, rank, world_size, port, remap_file_path): shard_w_2_m2.gather(0) -def _force_overwrite_worker(test_dir, rank, world_size, port): - """force=True re-save at an already-saved step must overwrite it. - - Saves step 5 with centroids=0, then re-saves the SAME step with different - params (centroids=7): a non-force re-save dedupes (no overwrite), a force - re-save overwrites. Reloads and asserts the persisted step-5 checkpoint - holds the later params. (This is the on_train_end post-fit checkpoint path: - a periodic save at the final step, then a forced re-save of the fitted - codebook at the same step.) - """ - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - dist.init_process_group(backend="gloo") - - class BufModel(nn.Module): - def __init__(self, fill): - super().__init__() - self.register_buffer("centroids", torch.full((4, 3), float(fill))) - - manager = checkpoint_util.CheckpointManager(test_dir, keep_checkpoint_max=0) - - # Initial save at step 5 (pre-fit: centroids = 0). - assert manager.maybe_save(5, BufModel(0.0), final=True), "initial save" - - # Same step, different params: non-force dedupes; force overwrites. - model = BufModel(7.0) - assert not manager.maybe_save(5, model, final=True, force=False), ( - "non-force same-step save must dedupe (not overwrite)" - ) - assert manager.maybe_save(5, model, final=True, force=True), ( - "force same-step save must fire" - ) - manager.close() # drain the async prune worker - - # Reload: the persisted step-5 checkpoint must hold the LATER params (7), - # i.e. the force-save overwrote the earlier (0) one. - restored = BufModel(0.0) - checkpoint_util.restore_model(os.path.join(test_dir, "model.ckpt-5"), restored) - assert torch.allclose(restored.centroids, torch.full((4, 3), 7.0)), ( - f"overwrite failed: centroids={restored.centroids.flatten().tolist()}" - ) - dist.destroy_process_group() - - class CheckpointUtilTest(unittest.TestCase): def setUp(self): if not os.path.exists("./tmp"): @@ -373,19 +327,6 @@ def test_checkpoint_manager_discovery(self): ) self.assertEqual(manager.best_checkpoint()[1], 10) - def test_force_overwrite_same_step(self): - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - p = ctx.Process( - target=_force_overwrite_worker, args=(self.test_dir, 0, 1, port) - ) - p.start() - p.join(timeout=120) - if p.is_alive(): - p.terminate() - raise RuntimeError("force-overwrite worker timed out.") - self.assertEqual(p.exitcode, 0, "force-overwrite worker failed") - def test_dist_save_restore_model(self): port = misc_util.get_free_port() procs = [] From 5bc89d4e193269e2772f0ca891fbe133002ae9f1 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:33:12 +0000 Subject: [PATCH 069/129] [refactor] typed FaissKmeansConfig proto; drop Struct + _coerce_proto_numbers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per maintainer review: replace the loosely-typed google.protobuf.Struct faiss_kmeans_kwargs with a strictly-typed FaissKmeansConfig message (niter, nredo, seed, max/min_points_per_centroid, spherical, verbose). Struct numbers arrive as floats and _coerce_proto_numbers heuristically int-ified them — a typed message is type-safe and removes that hack. gpu is omitted (CPU-only). SidRqkmeans builds the faiss kwargs from the typed message's set fields (ListFields), so unset fields fall back to faiss's own defaults. Updated the mock config and the test builder to the typed form. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 24 ++++---------------- tzrec/models/sid_rqkmeans_test.py | 7 +++--- tzrec/protos/models/sid_model.proto | 20 ++++++++++++---- tzrec/tests/configs/sid_rqkmeans_mock.config | 4 ++-- 4 files changed, 25 insertions(+), 30 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index b2188e8c3..17b94e1ff 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -31,22 +31,9 @@ ResidualKMeansQuantizer, ) from tzrec.protos.model_pb2 import ModelConfig -from tzrec.utils import config_util from tzrec.utils.logging_util import logger -def _coerce_proto_numbers(d: Dict) -> Dict: - """Coerce whole-valued floats back to int. - - ``Struct.number_value`` is always float, but faiss.Kmeans kwargs - (``niter``, ``seed``, ...) need ``int``. - """ - return { - k: int(v) if isinstance(v, float) and v.is_integer() else v - for k, v in d.items() - } - - class SidRqkmeans(BaseSidModel): """SID generation model using residual K-Means (FAISS-only). @@ -90,12 +77,11 @@ def __init__( cfg = self._model_config # SidRqkmeans proto message - # config_to_kwargs yields Struct numbers as floats; coerce back to int. - self._faiss_kwargs = ( - _coerce_proto_numbers(config_util.config_to_kwargs(cfg.faiss_kmeans_kwargs)) - if cfg.HasField("faiss_kmeans_kwargs") - else {} - ) + # Typed faiss kwargs: only the explicitly-set fields are forwarded, so + # unset ones fall back to faiss's own defaults (no float->int coercion). + self._faiss_kwargs = { + f.name: v for f, v in cfg.faiss_kmeans_kwargs.ListFields() + } self._quantizer = ResidualKMeansQuantizer( embed_dim=self._input_dim, diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index c41cb1cf1..db7fc6143 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -63,11 +63,10 @@ def _create_model( SID models read the item-embedding dense feature directly from the batch and do not consume feature_groups, so none is set. """ - from google.protobuf.struct_pb2 import Struct - n_embed_list = codebook if codebook is not None else [16] * n_layers - faiss_kwargs = Struct() - faiss_kwargs.update({"niter": niter, "verbose": False, "seed": 1234}) + faiss_kwargs = sid_model_pb2.FaissKmeansConfig( + niter=niter, verbose=False, seed=1234 + ) cfg = sid_model_pb2.SidRqkmeans( input_dim=input_dim, codebook=n_embed_list, diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index fdd41a22c..f6f07da2f 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -1,7 +1,19 @@ syntax = "proto2"; package tzrec.protos; -import "google/protobuf/struct.proto"; +// Strictly-typed subset of faiss.Kmeans(D, K, **kwargs) knobs. Unset fields +// fall back to faiss's own defaults (so it is safe to leave partially set). +// ``gpu`` is intentionally omitted — the fit is CPU-only (SidRqkmeans refuses +// a visible CUDA device). +message FaissKmeansConfig { + optional uint32 niter = 1; + optional uint32 nredo = 2; + optional uint32 seed = 3; + optional uint32 max_points_per_centroid = 4; + optional uint32 min_points_per_centroid = 5; + optional bool spherical = 6; + optional bool verbose = 7; +} message SidRqkmeans { // Input embedding dimension (K-Means runs directly on raw embeddings, @@ -15,10 +27,8 @@ message SidRqkmeans { repeated uint32 codebook = 3; // L2-normalize residuals before each layer. optional bool normalize_residuals = 4 [default = false]; - // Extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs) as a - // loosely-typed dict, e.g. {niter: 20, gpu: true, verbose: true, - // spherical: false, seed: 1234}. - optional google.protobuf.Struct faiss_kmeans_kwargs = 5; + // Strictly-typed extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs). + optional FaissKmeansConfig faiss_kmeans_kwargs = 5; // Target number of embeddings to reservoir-sample for the FAISS fit // (global, across all ranks). Bounds host memory regardless of corpus // size. 0 (the default) auto-derives it as max(K) * max_points_per_centroid diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config index d473dd705..0aad49cfb 100644 --- a/tzrec/tests/configs/sid_rqkmeans_mock.config +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -47,8 +47,8 @@ model_config { normalize_residuals: false embedding_feature_name: "item_emb" faiss_kmeans_kwargs { - fields { key: "niter" value { number_value: 5 } } - fields { key: "seed" value { number_value: 42 } } + niter: 5 + seed: 42 } } } From feeb4af1eccd79cd3a4461e365f4ee3fc7708ec1 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 06:39:42 +0000 Subject: [PATCH 070/129] [refactor] add QuantizeLayer base; KMeansLayer -> KMeansQuantizeLayer Per maintainer review: introduce a QuantizeLayer ABC (quantize / lookup / get_codebook_embeddings) so the K-Means and (PR3) RQ-VAE vector-quantize layers share one interface and the residual quantizer drives either uniformly. - new types.py: QuantizeOutput(embeddings, ids) NamedTuple (matches the PR3 feat/sid_abstract definition for a clean merge). - kmeans.py: add QuantizeLayer(nn.Module) ABC; rename KMeansLayer -> KMeansQuantizeLayer(QuantizeLayer); replace predict() with quantize()->QuantizeOutput (incl. the uninitialized-zeros path) + lookup() + get_codebook_embeddings(). - ResidualKMeansQuantizer: _quantize_layer/_lookup_code/get_codebook_embeddings delegate to the layer's quantize/lookup/get_codebook_embeddings. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 77 ++++++++++++++----- tzrec/modules/sid/kmeans_test.py | 36 +++++---- .../modules/sid/residual_kmeans_quantizer.py | 30 +++----- tzrec/modules/sid/types.py | 28 +++++++ 4 files changed, 120 insertions(+), 51 deletions(-) create mode 100644 tzrec/modules/sid/types.py diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 11df2b65e..5701230b7 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -14,20 +14,23 @@ This module is the single home for torch-native K-Means code used by SID models: -* :class:`KMeansLayer` — per-layer centroid container used by - :class:`ResidualKMeansQuantizer`. Centroids are injected - by the FAISS backend via ``load_centroids_``; the only forward path - is ``predict``. +* :class:`QuantizeLayer` — the per-layer quantizer interface + (``quantize`` / ``lookup`` / ``get_codebook_embeddings``) shared with the + RQ-VAE backend's vector-quantize layer. +* :class:`KMeansQuantizeLayer` — the K-Means implementation: a centroid + container populated by the FAISS backend via ``load_centroids_``. * :class:`ReservoirSampler` — bounded uniform stream sample (Vitter Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` fills during training to feed the one-shot FAISS fit. """ +from abc import abstractmethod from typing import Optional, Tuple import torch from torch import nn +from tzrec.modules.sid.types import QuantizeOutput from tzrec.utils.logging_util import logger @@ -162,11 +165,35 @@ def reset(self) -> None: self._n_seen = 0 -class KMeansLayer(nn.Module): +class QuantizeLayer(nn.Module): + """One quantize layer: assign inputs to a codebook and look codes up. + + Shared interface for the K-Means backend (:class:`KMeansQuantizeLayer`) + and the RQ-VAE backend's vector-quantize layer, so the residual quantizer + can drive either uniformly. + """ + + @abstractmethod + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" + raise NotImplementedError + + @abstractmethod + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + """Gather codebook embeddings for ``ids``.""" + raise NotImplementedError + + @abstractmethod + def get_codebook_embeddings(self) -> torch.Tensor: + """Return the full codebook, shape (n_clusters, D).""" + raise NotImplementedError + + +class KMeansQuantizeLayer(QuantizeLayer): """Single layer of a residual K-Means stack. Centroids are populated externally by ``load_centroids_`` (the FAISS - backend in :class:`ResidualKMeansQuantizer`); ``predict`` is the only + backend in :class:`ResidualKMeansQuantizer`); ``quantize`` is the only forward path. Args: @@ -198,11 +225,7 @@ def is_initialized(self) -> bool: return self._initialized def mark_initialized_(self) -> None: - """Flag centroids populated, syncing buffer + cached mirror. - - For callers that fill ``centroids`` in place (e.g. the DDP broadcast - in :meth:`SidRqkmeans.on_train_end`) rather than via ``load_centroids_``. - """ + """Flag centroids populated, syncing buffer + cached mirror.""" self._is_initialized.fill_(True) self._initialized = True @@ -247,23 +270,39 @@ def _load_from_state_dict( self._initialized = bool(self._is_initialized.item()) if self._initialized and self.centroids.abs().sum() == 0: error_msgs.append( - f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " - "are all zero — checkpoint was likely taken mid-FAISS-fit. " - "Re-run on_train_end to produce a valid checkpoint." + f"KMeansQuantizeLayer at '{prefix}': _is_initialized=True but " + "centroids are all zero — checkpoint was likely taken " + "mid-FAISS-fit. Re-run on_train_end to produce a valid checkpoint." ) @torch.no_grad() - def predict(self, batch: torch.Tensor) -> torch.Tensor: - """Assign points to nearest centroid. + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Assign points to the nearest centroid and gather them. Uses ``torch.cdist`` (L2); argmin is invariant to the monotonic sqrt, so assignments match squared-L2 except at exact equidistant ties (measure zero for real embeddings), where either centroid is valid. + Before the FAISS fit (uninitialized) this returns all-zero codes + + embeddings so the residual walk stays a no-op and the model is callable. + ``temperature`` is unused (no soft assignment). Args: - batch (Tensor): data points, shape (B, D). + x (Tensor): data points, shape (B, D). + temperature (float): unused. Returns: - Tensor: cluster indices, shape (B,). + QuantizeOutput: ``ids`` (B,) and ``embeddings`` (B, D). """ - return torch.cdist(batch, self.centroids).argmin(dim=-1) + if not self.is_initialized: + ids = torch.zeros(x.shape[0], dtype=torch.long, device=x.device) + return QuantizeOutput(embeddings=torch.zeros_like(x), ids=ids) + ids = torch.cdist(x, self.centroids).argmin(dim=-1) + return QuantizeOutput(embeddings=self.centroids[ids], ids=ids) + + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + """Gather centroids for ``ids``, shape (..., D).""" + return self.centroids[ids] + + def get_codebook_embeddings(self) -> torch.Tensor: + """Return the centroid table, shape (n_clusters, n_features).""" + return self.centroids diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py index d6b06a7f1..66a8de1a9 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_test.py @@ -14,7 +14,7 @@ import torch from tzrec.modules.sid.kmeans import ( - KMeansLayer, + KMeansQuantizeLayer, ReservoirSampler, recon_diagnostics, ) @@ -30,42 +30,52 @@ def test_recon_diagnostics_zero_on_identity(self) -> None: self.assertAlmostEqual(rel.item(), 0.0, places=6) -class KMeansLayerTest(unittest.TestCase): - """Tests for the single KMeansLayer.""" +class KMeansQuantizeLayerTest(unittest.TestCase): + """Tests for the single KMeansQuantizeLayer.""" def test_uninitialized_by_default(self) -> None: - layer = KMeansLayer(n_clusters=4, n_features=3) + layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) self.assertFalse(layer.is_initialized) self.assertEqual(layer.centroids.abs().sum().item(), 0.0) - def test_load_centroids_and_predict(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + def test_load_centroids_and_quantize(self) -> None: + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) layer.load_centroids_(centroids) self.assertTrue(layer.is_initialized) batch = torch.tensor([[0.1, 0.0], [9.0, 11.0]]) - codes = layer.predict(batch) - torch.testing.assert_close(codes, torch.tensor([0, 1])) + out = layer.quantize(batch) + torch.testing.assert_close(out.ids, torch.tensor([0, 1])) + # embeddings are the gathered centroids; lookup matches. + torch.testing.assert_close(out.embeddings, centroids[out.ids]) + torch.testing.assert_close(layer.lookup(out.ids), out.embeddings) + + def test_quantize_uninitialized_returns_zeros(self) -> None: + layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) + out = layer.quantize(torch.randn(5, 3)) + self.assertEqual(out.ids.shape, (5,)) + self.assertEqual(int(out.ids.abs().sum()), 0) + torch.testing.assert_close(out.embeddings, torch.zeros(5, 3)) def test_load_centroids_shape_mismatch_raises(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) with self.assertRaises(AssertionError): layer.load_centroids_(torch.zeros(3, 2)) def test_mid_fit_checkpoint_rejected(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) sd = layer.state_dict() # Simulate a mid-fit checkpoint: flag True but centroids still zero. sd["_is_initialized"] = torch.tensor(True) - fresh = KMeansLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): fresh.load_state_dict(sd) def test_post_fit_checkpoint_round_trips(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) - fresh = KMeansLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) fresh.load_state_dict(layer.state_dict()) self.assertTrue(fresh.is_initialized) torch.testing.assert_close(fresh.centroids, layer.centroids) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 0074331da..2b9f522c6 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans import KMeansLayer, recon_diagnostics +from tzrec.modules.sid.kmeans import KMeansQuantizeLayer, recon_diagnostics from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -35,8 +35,7 @@ class ResidualKMeansQuantizer(ResidualQuantizer): residual_0 = input for each layer i: (optionally) residual_i = L2_normalize(residual_i) - code_i = layer_i.predict(residual_i) - quantized_i = layer_i.centroids[code_i] + code_i, quantized_i = layer_i.quantize(residual_i) residual_{i+1} = residual_i - quantized_i output = sum of all quantized_i @@ -72,7 +71,7 @@ def __init__( self.layers = nn.ModuleList( [ - KMeansLayer( + KMeansQuantizeLayer( n_clusters=self.n_embed_list[i], n_features=embed_dim, ) @@ -86,29 +85,22 @@ def _quantize_layer( residual: torch.Tensor, temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Nearest-centroid assignment for one layer. + """Nearest-centroid assignment for one layer (delegates to the layer). Uninitialized layers (before ``train_offline``) return zeros, so the - residual walk is a no-op and the model stays callable. ``temperature`` - is unused (no soft assignment). + residual walk is a no-op and the model stays callable. Args: layer_idx (int): quantization layer index. residual (Tensor): current residual, shape (B, D). - temperature (float): unused. + temperature (float): unused (no soft assignment). Returns: codes (Tensor): cluster indices, shape (B,). quantized (Tensor): selected centroids, shape (B, D). """ - layer = self.layers[layer_idx] - if not layer.is_initialized: - codes = torch.zeros( - residual.shape[0], dtype=torch.long, device=residual.device - ) - return codes, torch.zeros_like(residual) - codes = layer.predict(residual) - return codes, layer.centroids[codes] + out = self.layers[layer_idx].quantize(residual, temperature) + return out.ids, out.embeddings def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Assign codes per layer and sum the centroids. @@ -146,11 +138,11 @@ def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: Returns: Tensor: centroids, shape (n_embed, embed_dim). """ - return self.layers[layer_idx].centroids + return self.layers[layer_idx].get_codebook_embeddings() def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: """Look up codebook vectors via the layer's centroid table.""" - return self.layers[layer_idx].centroids[code_idx] + return self.layers[layer_idx].lookup(code_idx) def default_fit_sample_size(self) -> int: """Points the FAISS fit subsamples to: max(K) * max_points_per_centroid. @@ -195,7 +187,7 @@ def train_offline( N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the - # all-zero poison guard in KMeansLayer would not catch. + # all-zero poison guard in KMeansQuantizeLayer would not catch. max_k = max(self.n_embed_list) assert N >= max_k, ( f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" diff --git a/tzrec/modules/sid/types.py b/tzrec/modules/sid/types.py new file mode 100644 index 000000000..2f0cf3c60 --- /dev/null +++ b/tzrec/modules/sid/types.py @@ -0,0 +1,28 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data types for SID generation: output tuples shared across quantizers.""" + +from typing import NamedTuple + +import torch + + +class QuantizeOutput(NamedTuple): + """One quantize layer's output. + + Attributes: + embeddings (Tensor): quantized embeddings, shape (B, D). + ids (Tensor): codebook indices, shape (B,). + """ + + embeddings: torch.Tensor + ids: torch.Tensor From a5d43b2e4640a0cabf0793caf65021da2b40c295 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:13:57 +0000 Subject: [PATCH 071/129] [refactor] unify reconstruction key to x_hat; drop _reconstruction hook MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit predict now exposes the reconstruction under predictions["x_hat"] (the same key RQ-VAE uses) instead of "quantized", and only in eval once the codebook is fit. With the key and readiness decided by the producer, BaseSidModel.update_metric is fully concrete — it gates on `"x_hat" in predictions` and needs no per-model _reconstruction hook (removed). RQ-VAE reuses update_metric as-is (it already emits x_hat); SidRqkmeans just gates the x_hat exposure on _quantizer.is_fitted. Update the predict-contract test to assert {codes, x_hat}. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 36 ++++++++++--------------------- tzrec/models/sid_rqkmeans.py | 28 ++++++------------------ tzrec/models/sid_rqkmeans_test.py | 10 ++++----- 3 files changed, 23 insertions(+), 51 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 51fd9a179..c0a0e9e56 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -40,9 +40,9 @@ class BaseSidModel(BaseModel): proxy). Subclasses build their quantizer in ``__init__`` (after calling - ``super().__init__``) and implement :meth:`predict`, :meth:`loss`, and - :meth:`_reconstruction` (which exposes the model's reconstruction of the - input embedding for the shared :meth:`update_metric`). + ``super().__init__``) and implement :meth:`predict` and :meth:`loss`. + :meth:`predict` exposes the reconstruction under ``predictions["x_hat"]`` + (only when meaningful) so the shared :meth:`update_metric` can score it. (:meth:`update_train_metric` defaults to a no-op.) Args: @@ -115,42 +115,28 @@ def init_metric(self) -> None: self._metric_modules["rel_loss"] = RelativeL1() self._metric_modules["unique_sid_ratio"] = UniqueRatio() - def _reconstruction( - self, predictions: Dict[str, torch.Tensor] - ) -> Optional[torch.Tensor]: - """The model's reconstruction of the input embedding, or None. - - Returns the (B, D) tensor that ``mse``/``rel_loss`` compare against the - input embedding — e.g. ``predictions["quantized"]`` (RQ-KMeans) or - ``predictions["x_hat"]`` (RQ-VAE). Returns None when it is unavailable or - not yet meaningful this step (e.g. before a K-Means fit), in which case - :meth:`update_metric` skips the eval metrics entirely. - - Args: - predictions (dict): a dict of predicted result. - """ - raise NotImplementedError - def update_metric( self, predictions: Dict[str, torch.Tensor], batch: Batch, losses: Optional[Dict[str, torch.Tensor]] = None, ) -> None: - """Update eval metrics from a reconstruction + the re-extracted input. + """Update eval metrics from the reconstruction + the re-extracted input. - The target embedding is re-extracted from ``batch`` (it is an input, not - a model output). All three metrics are gated on a non-None - :meth:`_reconstruction` so a not-yet-fitted model does not log garbage. + ``predictions["x_hat"]`` is the model's reconstruction of the input + embedding (the centroid sum for RQ-KMeans, the decoder output for + RQ-VAE). Subclasses expose it only when it is meaningful, so a + not-yet-fitted model omits it and this logs nothing. The target + embedding is re-extracted from ``batch`` (it is an input, not an output). Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ - recon = self._reconstruction(predictions) - if recon is None: + if "x_hat" not in predictions: return + recon = predictions["x_hat"] embedding = self._extract_feature(batch) self._metric_modules["mse"].update(recon, embedding) self._metric_modules["rel_loss"].update(recon, embedding) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 17b94e1ff..88568c051 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -140,8 +140,13 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: "codes": codes, } - if self.is_eval: - predictions["quantized"] = quantized + # Expose the centroid-sum reconstruction (``x_hat``, the scoring target + # for update_metric) only in eval AND once the codebook is fit — before + # on_train_end it is all-zeros, so omitting it makes update_metric skip. + # (Meaningful only with normalize_residuals=False; with normalization the + # centroids live on the rescaled-residual scale, off the input's scale.) + if self.is_eval and self._quantizer.is_fitted: + predictions["x_hat"] = quantized return predictions @@ -162,25 +167,6 @@ def loss( """ return {"dummy_loss": self._dummy_param.sum() * 0.0} - def _reconstruction( - self, predictions: Dict[str, torch.Tensor] - ) -> Optional[torch.Tensor]: - """Centroid-sum reconstruction, or None until the codebook is fit. - - ``quantized`` is present only in eval and is all-zeros before the - end-of-train FAISS fit, so gate on the fit — the shared - :meth:`BaseSidModel.update_metric` then skips the eval metrics until the - reconstruction is meaningful. (Meaningful only with - ``normalize_residuals=False``; with normalization the centroids live on - the rescaled-residual scale, so the two quantities don't share a scale.) - - Args: - predictions (dict): a dict of predicted result. - """ - if not self._quantizer.is_fitted: - return None - return predictions.get("quantized") - @torch.no_grad() def on_train_end(self) -> None: """Fit the FAISS codebook once, after the train_eval loop exits. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index db7fc6143..ecc96db86 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -223,7 +223,7 @@ def test_normalize_residuals_end_to_end(self) -> None: self.assertTrue((codes >= 0).all() and (codes < 16).all()) def test_eval_and_inference_predict_contract(self) -> None: - """Eval exposes codes + quantized only; inference is codes-only.""" + """Eval (post-fit) exposes codes + x_hat; inference is codes-only.""" try: import faiss # noqa: F401 except ImportError: @@ -236,12 +236,12 @@ def test_eval_and_inference_predict_contract(self) -> None: model.predict(_make_batch(B, input_dim)) model.on_train_end() - # Eval mode: the centroid-sum reconstruction is exposed for - # update_metric; the input embedding is NOT threaded through - # predictions (it is re-extracted from the batch in update_metric). + # Eval mode (fitted): the reconstruction is exposed as ``x_hat`` for + # update_metric; the input embedding is re-extracted from the batch + # there, not threaded through predictions. model.eval() eval_preds = model.predict(_make_batch(B, input_dim)) - self.assertEqual(set(eval_preds.keys()), {"codes", "quantized"}) + self.assertEqual(set(eval_preds.keys()), {"codes", "x_hat"}) # Inference (serving) mode: codes-only contract. model.set_is_inference(True) From c4c361a96253ef97e9de9fedc7f04d60e3fd6ac6 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:16:47 +0000 Subject: [PATCH 072/129] [style] SID: trim redundant comments Tighten a few over-long/obvious comments in predict (x_hat exposure, train branch), drop the `cfg = self._model_config` inline note, and shorten the host-tensor assert comment in train_offline. Comments only; load-bearing "why" notes kept. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 13 +++++-------- tzrec/modules/sid/residual_kmeans_quantizer.py | 5 ++--- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 88568c051..28c8a16bc 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -75,7 +75,7 @@ def __init__( "Launch with --nproc-per-node=1." ) - cfg = self._model_config # SidRqkmeans proto message + cfg = self._model_config # Typed faiss kwargs: only the explicitly-set fields are forwarded, so # unset ones fall back to faiss's own defaults (no float->int coercion). @@ -123,8 +123,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """ embedding = self._extract_feature(batch) - # Training: just reservoir-sample for the end-of-loop FAISS fit and - # return dummy codes — the codebook does not exist yet. + # Training: reservoir-sample only; codes are dummy until the fit. if self.is_train: self._reservoir.add(embedding) B = embedding.shape[0] @@ -140,11 +139,9 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: "codes": codes, } - # Expose the centroid-sum reconstruction (``x_hat``, the scoring target - # for update_metric) only in eval AND once the codebook is fit — before - # on_train_end it is all-zeros, so omitting it makes update_metric skip. - # (Meaningful only with normalize_residuals=False; with normalization the - # centroids live on the rescaled-residual scale, off the input's scale.) + # Expose the centroid-sum reconstruction (``x_hat``) for update_metric + # only once fitted — pre-fit it is all-zeros, so omitting it skips the + # eval metrics. (Meaningful only with normalize_residuals=False.) if self.is_eval and self._quantizer.is_fitted: predictions["x_hat"] = quantized diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 2b9f522c6..3ed9d7087 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -172,9 +172,8 @@ def train_offline( owned float32 tensor; not mutated. verbose (bool): print per-layer reconstruction loss. Default: True. """ - # CPU-only: SidRqkmeans refuses to init when CUDA is visible, but this - # quantizer is a standalone module — assert the host-tensor contract it - # relies on so misuse fails here, not deep inside faiss. + # Assert the host-tensor contract locally (this is a standalone module) + # so misuse fails here, not deep inside faiss. assert not inputs.is_cuda, "train_offline is CPU-only; got a CUDA tensor" assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" From db7f2beb8e726ba116b89fca8eff55c8e51c482b Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:21:21 +0000 Subject: [PATCH 073/129] [refactor] QuantizeLayer: make lookup concrete in the base MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit lookup(ids) is backend-independent given the codebook, so define it once in QuantizeLayer as get_codebook_embeddings()[ids] and drop KMeansQuantizeLayer's override. get_codebook_embeddings stays abstract — the codebook lives in a backend-specific attribute (centroids buffer vs nn.Embedding), so only it (and quantize) need a per-subclass implementation. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 5701230b7..9dc15fb55 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -178,14 +178,18 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" raise NotImplementedError - @abstractmethod def lookup(self, ids: torch.Tensor) -> torch.Tensor: - """Gather codebook embeddings for ``ids``.""" - raise NotImplementedError + """Gather codebook embeddings for ``ids`` (indexes the codebook).""" + return self.get_codebook_embeddings()[ids] @abstractmethod def get_codebook_embeddings(self) -> torch.Tensor: - """Return the full codebook, shape (n_clusters, D).""" + """Return the full codebook, shape (n_clusters, D). + + The codebook lives in a backend-specific attribute (a ``centroids`` + buffer for K-Means, an ``nn.Embedding`` for RQ-VAE), so this stays + abstract; :meth:`lookup` is then concrete in terms of it. + """ raise NotImplementedError @@ -299,10 +303,6 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: ids = torch.cdist(x, self.centroids).argmin(dim=-1) return QuantizeOutput(embeddings=self.centroids[ids], ids=ids) - def lookup(self, ids: torch.Tensor) -> torch.Tensor: - """Gather centroids for ``ids``, shape (..., D).""" - return self.centroids[ids] - def get_codebook_embeddings(self) -> torch.Tensor: """Return the centroid table, shape (n_clusters, n_features).""" return self.centroids From ed12cff7cee38739dac6118d5ed51256d30d235a Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 07:24:37 +0000 Subject: [PATCH 074/129] [refactor] QuantizeLayer: own n_clusters/n_features in the base Every quantize layer has a codebook of n_clusters x n_features, so store that shape in QuantizeLayer.__init__; KMeansQuantizeLayer passes them via super() and builds its centroid buffer from them. (PR3's vector-quantize layer maps its n_embed/embed_dim onto the same base params.) Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py index 9dc15fb55..4f6450388 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans.py @@ -170,9 +170,19 @@ class QuantizeLayer(nn.Module): Shared interface for the K-Means backend (:class:`KMeansQuantizeLayer`) and the RQ-VAE backend's vector-quantize layer, so the residual quantizer - can drive either uniformly. + can drive either uniformly. Owns the codebook shape; subclasses build the + backend-specific codebook (a buffer, an ``nn.Embedding``, …) from it. + + Args: + n_clusters (int): number of codebook entries. + n_features (int): feature dimension. """ + def __init__(self, n_clusters: int, n_features: int) -> None: + super().__init__() + self.n_clusters = n_clusters + self.n_features = n_features + @abstractmethod def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" @@ -210,10 +220,7 @@ def __init__( n_clusters: int, n_features: int, ) -> None: - super().__init__() - self.n_clusters = n_clusters - self.n_features = n_features - + super().__init__(n_clusters, n_features) self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) # Persistent so a post-fit checkpoint round-trips; a mid-fit poison # (True flag + zero centroids) is caught in _load_from_state_dict. From d2697eb7a140de8c3b7a048e4b87ba3a362ed24d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 08:00:03 +0000 Subject: [PATCH 075/129] [refactor] SID: extract QuantizeLayer ABC; rename kmeans -> kmeans_quantize - Add tzrec/modules/sid/quantize_layer.py: QuantizeLayer ABC shared by the K-Means backend and (PR3) the RQ-VAE vector-quantize layer. Owns the codebook shape (n_embed, embed_dim); concrete lookup() in terms of an abstract get_codebook_embeddings(). Adds quantize_layer_test.py. - Rename kmeans.py -> kmeans_quantize.py (parallel to vector_quantize.py) and CentroidQuantizeLayer -> KMeansQuantizeLayer; KMeansQuantizeLayer now subclasses QuantizeLayer. - ResidualKMeansQuantizer.train_offline now CONSUMES its input (may mutate in place); the copy decision is the caller's. on_train_end hands over the reservoir buffer by ownership (no copy) since nothing reads it afterward. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans.py | 5 +- .../sid/{kmeans.py => kmeans_quantize.py} | 71 +++------------- ...kmeans_test.py => kmeans_quantize_test.py} | 18 ++--- tzrec/modules/sid/quantize_layer.py | 58 +++++++++++++ tzrec/modules/sid/quantize_layer_test.py | 81 +++++++++++++++++++ .../modules/sid/residual_kmeans_quantizer.py | 22 ++--- 6 files changed, 177 insertions(+), 78 deletions(-) rename tzrec/modules/sid/{kmeans.py => kmeans_quantize.py} (79%) rename tzrec/modules/sid/{kmeans_test.py => kmeans_quantize_test.py} (91%) create mode 100644 tzrec/modules/sid/quantize_layer.py create mode 100644 tzrec/modules/sid/quantize_layer_test.py diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 28c8a16bc..07ce132f9 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -26,7 +26,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel -from tzrec.modules.sid.kmeans import ReservoirSampler +from tzrec.modules.sid.kmeans_quantize import ReservoirSampler from tzrec.modules.sid.residual_kmeans_quantizer import ( ResidualKMeansQuantizer, ) @@ -178,6 +178,9 @@ def on_train_end(self) -> None: An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped. """ + # train_offline consumes its input; we hand it the reservoir buffer + # directly (no copy) since nothing reads it after this — reset() drops + # the sampler's reference and ``local`` is the last user of the storage. local = self._reservoir.sample() self._reservoir.reset() diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans_quantize.py similarity index 79% rename from tzrec/modules/sid/kmeans.py rename to tzrec/modules/sid/kmeans_quantize.py index 4f6450388..872783893 100644 --- a/tzrec/modules/sid/kmeans.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -14,22 +14,18 @@ This module is the single home for torch-native K-Means code used by SID models: -* :class:`QuantizeLayer` — the per-layer quantizer interface - (``quantize`` / ``lookup`` / ``get_codebook_embeddings``) shared with the - RQ-VAE backend's vector-quantize layer. -* :class:`KMeansQuantizeLayer` — the K-Means implementation: a centroid - container populated by the FAISS backend via ``load_centroids_``. +* :class:`KMeansQuantizeLayer` — the K-Means :class:`QuantizeLayer`: a + centroid container populated by the FAISS backend via ``load_centroids_``. * :class:`ReservoirSampler` — bounded uniform stream sample (Vitter Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` fills during training to feed the one-shot FAISS fit. """ -from abc import abstractmethod from typing import Optional, Tuple import torch -from torch import nn +from tzrec.modules.sid.quantize_layer import QuantizeLayer from tzrec.modules.sid.types import QuantizeOutput from tzrec.utils.logging_util import logger @@ -165,63 +161,22 @@ def reset(self) -> None: self._n_seen = 0 -class QuantizeLayer(nn.Module): - """One quantize layer: assign inputs to a codebook and look codes up. - - Shared interface for the K-Means backend (:class:`KMeansQuantizeLayer`) - and the RQ-VAE backend's vector-quantize layer, so the residual quantizer - can drive either uniformly. Owns the codebook shape; subclasses build the - backend-specific codebook (a buffer, an ``nn.Embedding``, …) from it. - - Args: - n_clusters (int): number of codebook entries. - n_features (int): feature dimension. - """ - - def __init__(self, n_clusters: int, n_features: int) -> None: - super().__init__() - self.n_clusters = n_clusters - self.n_features = n_features - - @abstractmethod - def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: - """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" - raise NotImplementedError - - def lookup(self, ids: torch.Tensor) -> torch.Tensor: - """Gather codebook embeddings for ``ids`` (indexes the codebook).""" - return self.get_codebook_embeddings()[ids] - - @abstractmethod - def get_codebook_embeddings(self) -> torch.Tensor: - """Return the full codebook, shape (n_clusters, D). - - The codebook lives in a backend-specific attribute (a ``centroids`` - buffer for K-Means, an ``nn.Embedding`` for RQ-VAE), so this stays - abstract; :meth:`lookup` is then concrete in terms of it. - """ - raise NotImplementedError - - class KMeansQuantizeLayer(QuantizeLayer): - """Single layer of a residual K-Means stack. + """K-Means :class:`QuantizeLayer`: a centroid codebook + nearest assignment. Centroids are populated externally by ``load_centroids_`` (the FAISS backend in :class:`ResidualKMeansQuantizer`); ``quantize`` is the only - forward path. + forward path. (The k-means *fit* lives in the quantizer; this layer just + holds the resulting centroids.) Args: - n_clusters (int): number of clusters (codebook size). - n_features (int): feature dimension. + n_embed (int): number of centroids (codebook size). + embed_dim (int): feature dimension. """ - def __init__( - self, - n_clusters: int, - n_features: int, - ) -> None: - super().__init__(n_clusters, n_features) - self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) + def __init__(self, n_embed: int, embed_dim: int) -> None: + super().__init__(n_embed, embed_dim) + self.register_buffer("centroids", torch.zeros(n_embed, embed_dim)) # Persistent so a post-fit checkpoint round-trips; a mid-fit poison # (True flag + zero centroids) is caught in _load_from_state_dict. self.register_buffer("_is_initialized", torch.tensor(False)) @@ -246,7 +201,7 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: Args: centroids (Tensor): externally trained centroids, - shape (n_clusters, n_features). + shape (n_embed, embed_dim). """ assert centroids.shape == self.centroids.shape, ( f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " @@ -311,5 +266,5 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: return QuantizeOutput(embeddings=self.centroids[ids], ids=ids) def get_codebook_embeddings(self) -> torch.Tensor: - """Return the centroid table, shape (n_clusters, n_features).""" + """Return the centroid table, shape (n_embed, embed_dim).""" return self.centroids diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_quantize_test.py similarity index 91% rename from tzrec/modules/sid/kmeans_test.py rename to tzrec/modules/sid/kmeans_quantize_test.py index 66a8de1a9..9c2df2611 100644 --- a/tzrec/modules/sid/kmeans_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -13,7 +13,7 @@ import torch -from tzrec.modules.sid.kmeans import ( +from tzrec.modules.sid.kmeans_quantize import ( KMeansQuantizeLayer, ReservoirSampler, recon_diagnostics, @@ -34,12 +34,12 @@ class KMeansQuantizeLayerTest(unittest.TestCase): """Tests for the single KMeansQuantizeLayer.""" def test_uninitialized_by_default(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) + layer = KMeansQuantizeLayer(n_embed=4, embed_dim=3) self.assertFalse(layer.is_initialized) self.assertEqual(layer.centroids.abs().sum().item(), 0.0) def test_load_centroids_and_quantize(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) layer.load_centroids_(centroids) self.assertTrue(layer.is_initialized) @@ -52,30 +52,30 @@ def test_load_centroids_and_quantize(self) -> None: torch.testing.assert_close(layer.lookup(out.ids), out.embeddings) def test_quantize_uninitialized_returns_zeros(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=4, n_features=3) + layer = KMeansQuantizeLayer(n_embed=4, embed_dim=3) out = layer.quantize(torch.randn(5, 3)) self.assertEqual(out.ids.shape, (5,)) self.assertEqual(int(out.ids.abs().sum()), 0) torch.testing.assert_close(out.embeddings, torch.zeros(5, 3)) def test_load_centroids_shape_mismatch_raises(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) with self.assertRaises(AssertionError): layer.load_centroids_(torch.zeros(3, 2)) def test_mid_fit_checkpoint_rejected(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) sd = layer.state_dict() # Simulate a mid-fit checkpoint: flag True but centroids still zero. sd["_is_initialized"] = torch.tensor(True) - fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_embed=2, embed_dim=2) with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): fresh.load_state_dict(sd) def test_post_fit_checkpoint_round_trips(self) -> None: - layer = KMeansQuantizeLayer(n_clusters=2, n_features=2) + layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) - fresh = KMeansQuantizeLayer(n_clusters=2, n_features=2) + fresh = KMeansQuantizeLayer(n_embed=2, embed_dim=2) fresh.load_state_dict(layer.state_dict()) self.assertTrue(fresh.is_initialized) torch.testing.assert_close(fresh.centroids, layer.centroids) diff --git a/tzrec/modules/sid/quantize_layer.py b/tzrec/modules/sid/quantize_layer.py new file mode 100644 index 000000000..e7f344fda --- /dev/null +++ b/tzrec/modules/sid/quantize_layer.py @@ -0,0 +1,58 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""QuantizeLayer: the per-layer quantizer interface shared by SID backends.""" + +from abc import abstractmethod + +import torch +from torch import nn + +from tzrec.modules.sid.types import QuantizeOutput + + +class QuantizeLayer(nn.Module): + """One quantize layer: assign inputs to a codebook and look codes up. + + Shared interface for the K-Means backend + (:class:`~tzrec.modules.sid.kmeans_quantize.KMeansQuantizeLayer`) and the RQ-VAE + backend's vector-quantize layer, so the residual quantizer can drive either + uniformly. Owns the codebook shape; subclasses build the backend-specific + codebook (a buffer, an ``nn.Embedding``, …) from it. + + Args: + n_embed (int): number of codebook entries. + embed_dim (int): feature dimension. + """ + + def __init__(self, n_embed: int, embed_dim: int) -> None: + super().__init__() + self.n_embed = n_embed + self.embed_dim = embed_dim + + @abstractmethod + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" + raise NotImplementedError + + def lookup(self, ids: torch.Tensor) -> torch.Tensor: + """Gather codebook embeddings for ``ids`` (indexes the codebook).""" + return self.get_codebook_embeddings()[ids] + + @abstractmethod + def get_codebook_embeddings(self) -> torch.Tensor: + """Return the full codebook, shape (n_embed, embed_dim). + + The codebook lives in a backend-specific attribute (a ``centroids`` + buffer for K-Means, an ``nn.Embedding`` for RQ-VAE), so this stays + abstract; :meth:`lookup` is then concrete in terms of it. + """ + raise NotImplementedError diff --git a/tzrec/modules/sid/quantize_layer_test.py b/tzrec/modules/sid/quantize_layer_test.py new file mode 100644 index 000000000..28eb4849b --- /dev/null +++ b/tzrec/modules/sid/quantize_layer_test.py @@ -0,0 +1,81 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid.quantize_layer import QuantizeLayer +from tzrec.modules.sid.types import QuantizeOutput + + +class _StubQuantizeLayer(QuantizeLayer): + """Minimal concrete subclass: a fixed codebook, nearest-row assignment. + + Exercises the base class's concrete ``__init__`` / ``lookup`` without + pulling in a backend (FAISS / nn.Embedding). + """ + + def __init__(self, n_embed: int, embed_dim: int) -> None: + super().__init__(n_embed, embed_dim) + # A deterministic codebook so lookup/quantize are checkable by hand. + self._codebook = torch.arange(n_embed * embed_dim, dtype=torch.float32).reshape( + n_embed, embed_dim + ) + + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + dist = torch.cdist(x, self._codebook) + ids = dist.argmin(dim=-1) + return QuantizeOutput(embeddings=self.lookup(ids), ids=ids) + + def get_codebook_embeddings(self) -> torch.Tensor: + return self._codebook + + +class QuantizeLayerTest(unittest.TestCase): + """Tests for the shared QuantizeLayer base class.""" + + def test_init_stores_codebook_shape(self) -> None: + layer = _StubQuantizeLayer(n_embed=4, embed_dim=3) + self.assertEqual(layer.n_embed, 4) + self.assertEqual(layer.embed_dim, 3) + + def test_lookup_gathers_codebook_rows(self) -> None: + layer = _StubQuantizeLayer(n_embed=4, embed_dim=3) + ids = torch.tensor([0, 2, 3, 1]) + out = layer.lookup(ids) + torch.testing.assert_close(out, layer.get_codebook_embeddings()[ids]) + self.assertEqual(out.shape, (4, 3)) + + def test_quantize_assigns_exact_codebook_rows(self) -> None: + # Feeding codebook rows back in must recover their own indices. + layer = _StubQuantizeLayer(n_embed=4, embed_dim=3) + x = layer.get_codebook_embeddings().clone() + out = layer.quantize(x) + torch.testing.assert_close(out.ids, torch.arange(4)) + torch.testing.assert_close(out.embeddings, x) + + def test_abstract_methods_unoverridden_raise(self) -> None: + # The abstract methods are documented to raise if a subclass forgets + # to implement them; QuantizeLayer relies on nn.Module (no ABCMeta), + # so this guards that the bodies still fail loudly rather than no-op. + class _Incomplete(QuantizeLayer): + pass + + layer = _Incomplete(n_embed=2, embed_dim=2) + with self.assertRaises(NotImplementedError): + layer.get_codebook_embeddings() + with self.assertRaises(NotImplementedError): + layer.quantize(torch.zeros(1, 2)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 3ed9d7087..ddd6154f2 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans import KMeansQuantizeLayer, recon_diagnostics +from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer, recon_diagnostics from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -72,8 +72,8 @@ def __init__( self.layers = nn.ModuleList( [ KMeansQuantizeLayer( - n_clusters=self.n_embed_list[i], - n_features=embed_dim, + n_embed=self.n_embed_list[i], + embed_dim=embed_dim, ) for i in range(n_layers) ] @@ -168,8 +168,9 @@ def train_offline( ``SEARCH_CHUNK``-sized chunks to cap peak memory. Args: - inputs (Tensor): embedding matrix (N, D) on CPU. Copied once to an - owned float32 tensor; not mutated. + inputs (Tensor): embedding matrix (N, D) on CPU. CONSUMED: the + residual pass may mutate it in place, so the caller must not + rely on its contents afterward (copy first if it needs them). verbose (bool): print per-layer reconstruction loss. Default: True. """ # Assert the host-tensor contract locally (this is a standalone module) @@ -178,11 +179,12 @@ def train_offline( assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # The loop below mutates x in place (the residual ``x -= q``), and the - # input is a view into the caller's float32 reservoir buffer — so own a - # fresh copy (copy=True forces one even when the dtype already matches, - # avoiding the double copy a separate ``.clone()`` would add). - x = inputs.detach().to(dtype=torch.float32, copy=True).contiguous() + # train_offline CONSUMES its input: the residual loop below mutates x + # in place (``x -= q``). We only normalize dtype/layout for faiss — a + # no-op view when the input is already float32 + contiguous, so the + # mutation lands in the caller's buffer (intended; the caller copies + # first if it still needs the data). + x = inputs.detach().to(dtype=torch.float32).contiguous() N = x.shape[0] # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not # errors) when N < K and returns a degenerate codebook, which the From 097e9eb0fd0f352004247be4237cbb092e30c561 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 08:08:08 +0000 Subject: [PATCH 076/129] [docs] checkpoint_util: tighten maybe_save `final` param docstring Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/utils/checkpoint_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tzrec/utils/checkpoint_util.py b/tzrec/utils/checkpoint_util.py index c601fd432..ede4ef4de 100644 --- a/tzrec/utils/checkpoint_util.py +++ b/tzrec/utils/checkpoint_util.py @@ -417,8 +417,7 @@ def maybe_save( epoch: current epoch; enables the epoch trigger when not None. data_timestamp: this rank's consumed event-time (seconds), -1.0 if none; reconciled across workers (quorum) for the event-time trigger. - final: request a save unconditionally (still subject to the per-step - dedupe), e.g. at train end. + final: force a save (still subject to the dedupe), e.g. at train end. Returns: True if a checkpoint was saved. From a9a889c2ddff65bd907f535db57c2c949b0e8802 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 11:44:09 +0000 Subject: [PATCH 077/129] [fix] SID: review fixes + fail-fast validation; fix integration test CPU pin - sid_integration_test: force CPU with CUDA_VISIBLE_DEVICES="-1" not "" (empty is treated inconsistently across CUDA runtimes; the GPU CI runner didn't hide devices, tripping the CPU-only guard in the train_eval child). - BaseSidModel: validate codebook entries >=1 and input_dim >=1 at construction; guard feature width in _extract_feature (a (B,1) tensor would otherwise broadcast into a degenerate rank-1 codebook). assert -> raise. - residual_kmeans_quantizer / kmeans_quantize: assert -> raise for the data-corruption guards (N>=max_k, load_centroids_ shape, CPU/shape contract) so they survive python -O. - RelativeL1: float64 sum / long count to avoid float32 rounding past 2**24. - kmeans_quantize: drop the duplicate relative_l1/recon_diagnostics helpers; RelativeL1 (tzrec/metrics) is the single home of the formula. Per-layer offline-fit log now reports MSE only. - sid_rqkmeans: TODO documenting the periodic-checkpointing-disabled contract (codebook can be dropped by save dedupe otherwise). - sid_model.proto: drop stale "(global, across all ranks)" wording. - mock config: set save_checkpoints_steps/epochs = 0 (the documented convention). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/relative_l1.py | 15 ++++- tzrec/models/sid_model.py | 23 +++++++- tzrec/models/sid_rqkmeans.py | 13 ++++- tzrec/modules/sid/kmeans_quantize.py | 56 +++---------------- tzrec/modules/sid/kmeans_quantize_test.py | 13 +---- .../modules/sid/residual_kmeans_quantizer.py | 34 +++++------ tzrec/protos/models/sid_model.proto | 4 +- tzrec/tests/configs/sid_rqkmeans_mock.config | 2 + tzrec/tests/sid_integration_test.py | 5 +- 9 files changed, 79 insertions(+), 86 deletions(-) diff --git a/tzrec/metrics/relative_l1.py b/tzrec/metrics/relative_l1.py index 72a55c28d..5aa00f4e4 100644 --- a/tzrec/metrics/relative_l1.py +++ b/tzrec/metrics/relative_l1.py @@ -29,8 +29,17 @@ class RelativeL1(Metric): def __init__(self, epsilon: float = 1e-4, **kwargs) -> None: super().__init__(**kwargs) self.epsilon = epsilon - self.add_state("sum_rel", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("count", default=torch.tensor(0.0), dist_reduce_fx="sum") + # float64 sum / long count: element-wise aggregation crosses 2**24 at + # only ~32K rows of a 512-dim embedding, past which float32 increments + # round (mirrors the float64 care in ``ReservoirSampler.add``). + self.add_state( + "sum_rel", + default=torch.tensor(0.0, dtype=torch.float64), + dist_reduce_fx="sum", + ) + self.add_state( + "count", default=torch.tensor(0, dtype=torch.long), dist_reduce_fx="sum" + ) def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: """Accumulate the relative-L1 error between ``preds`` and ``target``. @@ -42,7 +51,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: rel = torch.abs(target - preds) / ( torch.maximum(torch.abs(target), torch.abs(preds)) + self.epsilon ) - self.sum_rel += rel.sum() + self.sum_rel += rel.sum().double() self.count += rel.numel() def compute(self) -> torch.Tensor: diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index c0a0e9e56..579ab6702 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -70,8 +70,17 @@ def __init__( self._input_dim = cfg.input_dim self._normalize_residuals = cfg.normalize_residuals - assert cfg.codebook, "codebook must be set, e.g. [256, 256, 256]" + if not cfg.codebook: + raise ValueError("codebook must be set, e.g. [256, 256, 256]") self._n_embed_list = list(cfg.codebook) + # Fail fast: a zero codebook entry / input_dim==0 only errors opaquely + # deep inside faiss, after the whole training pass. + if any(k < 1 for k in self._n_embed_list): + raise ValueError( + f"every codebook entry must be >= 1, got {self._n_embed_list}" + ) + if self._input_dim < 1: + raise ValueError(f"input_dim must be >= 1, got {self._input_dim}") self._n_layers = len(self._n_embed_list) def _extract_feature( @@ -87,7 +96,17 @@ def _extract_feature( if feature_name is None: feature_name = self._embedding_feature_name kt = batch.dense_features[BASE_DATA_GROUP] - return kt[feature_name] + embedding = kt[feature_name] + # Guard a misconfigured feature width: a (B, 1) tensor (raw_feature + # missing value_dim, which defaults to 1) would otherwise broadcast + # silently downstream and fit a degenerate rank-1 codebook. + if embedding.dim() != 2 or embedding.shape[1] != self._input_dim: + raise ValueError( + f"feature '{feature_name}' has shape {tuple(embedding.shape)}, " + f"expected (B, {self._input_dim}); check that its value_dim " + "matches the SID input_dim." + ) + return embedding def init_loss(self) -> None: """Initialize loss modules. diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 07ce132f9..317fa4779 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -58,8 +58,10 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - # CPU-only: embeddings, reservoir, and FAISS fit all stay on the host, - # so there are no device copies. Refuse to run when CUDA is visible. + # CPU-only: training and inference both run on the host (embeddings, + # reservoir, FAISS fit, and post-fit assignment), so there are no device + # copies. v1 deliberately restricts the whole model to CPU; refuse to + # run when CUDA is visible. if torch.cuda.is_available(): raise RuntimeError( "SidRqkmeans is CPU-only, but a CUDA device is visible. " @@ -175,6 +177,13 @@ def on_train_end(self) -> None: then persists the fitted codebook (SID runs with periodic checkpointing disabled, so that save is never deduped away). + TODO: the "periodic checkpointing disabled" requirement is currently a + convention, not enforced. If a user sets save_checkpoints_steps/epochs + > 0 and the last in-loop save lands on the final step, the tail save is + deduped away and the fitted codebook is silently dropped. Harden the + save logic (enforce the contract / bypass the dedupe for this save) in a + future update. + An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped. """ diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index 872783893..6eb5b940a 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -21,7 +21,7 @@ fills during training to feed the one-shot FAISS fit. """ -from typing import Optional, Tuple +from typing import Optional import torch @@ -30,49 +30,6 @@ from tzrec.utils.logging_util import logger -def recon_diagnostics( - x: torch.Tensor, - out: torch.Tensor, - epsilon: float = 1e-4, -) -> Tuple[torch.Tensor, torch.Tensor]: - """MSE + relative-L1 reconstruction diagnostics. - - Shared by :meth:`SidRqkmeans.update_metric` and - :meth:`ResidualKMeansQuantizer.train_offline`'s per-layer log. - - Args: - x: ground-truth embedding, shape (B, D). - out: quantized reconstruction, shape (B, D). - epsilon: numerical stabilizer for the relative-L1 denominator. - - Returns: - mse: scalar ``((out - x) ** 2).mean()``. - rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. - """ - return ((out - x) ** 2).mean(), relative_l1(x, out, epsilon) - - -def relative_l1( - x: torch.Tensor, - out: torch.Tensor, - epsilon: float = 1e-4, -) -> torch.Tensor: - """Relative-L1 error ``mean(|x - out| / (max(|x|, |out|) + eps))``. - - Symmetric relative error in [0, 1] (verbatim port of OpenOneRec's - ``calc_loss``). Used standalone by :meth:`SidRqkmeans.update_metric` (which - needs only ``rel``, not the MSE :meth:`recon_diagnostics` also computes). - - Args: - x: ground-truth embedding, shape (B, D). - out: quantized reconstruction, shape (B, D). - epsilon: numerical stabilizer for the denominator. - """ - return ( - torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) - ).mean() - - class ReservoirSampler: """Bounded uniform sample of a stream (Vitter Algorithm R). @@ -203,10 +160,13 @@ def load_centroids_(self, centroids: torch.Tensor) -> None: centroids (Tensor): externally trained centroids, shape (n_embed, embed_dim). """ - assert centroids.shape == self.centroids.shape, ( - f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " - f"got {tuple(centroids.shape)}" - ) + # raise (not assert): under ``python -O`` a dropped assert would let a + # (1, D) tensor broadcast-replicate into all K centroid rows silently. + if centroids.shape != self.centroids.shape: + raise RuntimeError( + f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " + f"got {tuple(centroids.shape)}" + ) self.centroids.copy_( centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) ) diff --git a/tzrec/modules/sid/kmeans_quantize_test.py b/tzrec/modules/sid/kmeans_quantize_test.py index 9c2df2611..2f2883562 100644 --- a/tzrec/modules/sid/kmeans_quantize_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -16,20 +16,9 @@ from tzrec.modules.sid.kmeans_quantize import ( KMeansQuantizeLayer, ReservoirSampler, - recon_diagnostics, ) -class KmeansHelpersTest(unittest.TestCase): - """Tests for the K-Means helper functions.""" - - def test_recon_diagnostics_zero_on_identity(self) -> None: - x = torch.randn(8, 4) - mse, rel = recon_diagnostics(x, x.clone()) - self.assertAlmostEqual(mse.item(), 0.0, places=6) - self.assertAlmostEqual(rel.item(), 0.0, places=6) - - class KMeansQuantizeLayerTest(unittest.TestCase): """Tests for the single KMeansQuantizeLayer.""" @@ -60,7 +49,7 @@ def test_quantize_uninitialized_returns_zeros(self) -> None: def test_load_centroids_shape_mismatch_raises(self) -> None: layer = KMeansQuantizeLayer(n_embed=2, embed_dim=2) - with self.assertRaises(AssertionError): + with self.assertRaises(RuntimeError): layer.load_centroids_(torch.zeros(3, 2)) def test_mid_fit_checkpoint_rejected(self) -> None: diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index ddd6154f2..4056ca861 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -23,7 +23,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer, recon_diagnostics +from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -173,12 +173,15 @@ def train_offline( rely on its contents afterward (copy first if it needs them). verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Assert the host-tensor contract locally (this is a standalone module) - # so misuse fails here, not deep inside faiss. - assert not inputs.is_cuda, "train_offline is CPU-only; got a CUDA tensor" - assert inputs.dim() == 2 and inputs.shape[1] == self.embed_dim, ( - f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" - ) + # Check the host-tensor contract locally (this is a standalone module) + # so misuse fails here, not deep inside faiss. raise (not assert): these + # guard silent data corruption and must survive ``python -O``. + if inputs.is_cuda: + raise RuntimeError("train_offline is CPU-only; got a CUDA tensor") + if inputs.dim() != 2 or inputs.shape[1] != self.embed_dim: + raise RuntimeError( + f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" + ) # train_offline CONSUMES its input: the residual loop below mutates x # in place (``x -= q``). We only normalize dtype/layout for faiss — a # no-op view when the input is already float32 + contiguous, so the @@ -190,9 +193,11 @@ def train_offline( # errors) when N < K and returns a degenerate codebook, which the # all-zero poison guard in KMeansQuantizeLayer would not catch. max_k = max(self.n_embed_list) - assert N >= max_k, ( - f"need >= {max_k} points to fit the codebook (largest layer K), got N={N}" - ) + if N < max_k: + raise RuntimeError( + f"need >= {max_k} points to fit the codebook (largest layer K), " + f"got N={N}" + ) out = torch.zeros_like(x) # x0 (original input) feeds the per-layer recon log. Without # normalization ``out + x == x0``, so it's rebuilt on the fly below and @@ -254,9 +259,6 @@ def train_offline( ) @staticmethod - def _calc_loss( - x: torch.Tensor, out: torch.Tensor, epsilon: float = 1e-4 - ) -> Dict[str, float]: - """Reconstruction loss diagnostics (MSE + relative L1).""" - loss, rel_loss = recon_diagnostics(x, out, epsilon=epsilon) - return {"loss": float(loss.item()), "rel_loss": float(rel_loss.item())} + def _calc_loss(x: torch.Tensor, out: torch.Tensor) -> Dict[str, float]: + """Per-layer reconstruction MSE for the offline-fit log.""" + return {"loss": float(((out - x) ** 2).mean().item())} diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index f6f07da2f..e51462efa 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -29,8 +29,8 @@ message SidRqkmeans { optional bool normalize_residuals = 4 [default = false]; // Strictly-typed extra kwargs forwarded to faiss.Kmeans(D, K, **kwargs). optional FaissKmeansConfig faiss_kmeans_kwargs = 5; - // Target number of embeddings to reservoir-sample for the FAISS fit - // (global, across all ranks). Bounds host memory regardless of corpus + // Target number of embeddings to reservoir-sample for the FAISS fit. + // Bounds host memory regardless of corpus // size. 0 (the default) auto-derives it as max(K) * max_points_per_centroid // (the largest per-layer codebook, for non-uniform codebooks) — exactly // what FAISS subsamples to internally (default 256), so no training points diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config index 0aad49cfb..0e6dec907 100644 --- a/tzrec/tests/configs/sid_rqkmeans_mock.config +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -17,6 +17,8 @@ train_config { } } num_epochs: 1 + save_checkpoints_steps: 0 + save_checkpoints_epochs: 0 } eval_config { } diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 711e69ec0..94c5216e7 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -34,8 +34,11 @@ def setUp(self): # SID models are CPU-only (refuse a visible CUDA device) and # single-process (refuse world_size > 1), so hide CUDA and pin # nproc=1 — the GPU CI harness otherwise defaults to GPU + nproc=2. + # Use "-1", not "" — an empty CUDA_VISIBLE_DEVICES is treated + # inconsistently across CUDA runtimes (the GPU CI runner does not hide + # the devices), which trips the CPU-only guard in the train_eval child. patcher = mock.patch.dict( - os.environ, {"CUDA_VISIBLE_DEVICES": "", "TEST_NPROC_PER_NODE": "1"} + os.environ, {"CUDA_VISIBLE_DEVICES": "-1", "TEST_NPROC_PER_NODE": "1"} ) patcher.start() self.addCleanup(patcher.stop) From 3b41df9b2921ae280fd0ce92375ef680bc1e61f8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 12:22:18 +0000 Subject: [PATCH 078/129] [review] SID: doc fixes, negative tests, stronger integration assertions - CPU-only guard message recommends CUDA_VISIBLE_DEVICES="-1" (not "", which this PR found unreliable on the GPU CI runner). - Correct the train_offline comment: faiss throws (not warns) for N < K. - Add negative tests for the fail-fast guards: empty/zero codebook, input_dim<1, feature-width mismatch, and train_offline too-few-points / wrong-dim. - sid_integration_test: assert the post-fit eval reports finite mse/rel_loss/ unique_sid_ratio (rel_loss < 1.0, unique_sid_ratio > 0) so a degenerate / unfitted codebook can't keep the test green. - Trim verbose comments. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/metrics/relative_l1.py | 5 ++- tzrec/models/sid_model.py | 5 ++- tzrec/models/sid_rqkmeans.py | 23 ++++++------- tzrec/models/sid_rqkmeans_test.py | 21 ++++++++++++ .../modules/sid/residual_kmeans_quantizer.py | 19 +++++------ .../sid/residual_kmeans_quantizer_test.py | 12 +++++++ tzrec/tests/sid_integration_test.py | 32 +++++++++++++------ 7 files changed, 76 insertions(+), 41 deletions(-) diff --git a/tzrec/metrics/relative_l1.py b/tzrec/metrics/relative_l1.py index 5aa00f4e4..685307608 100644 --- a/tzrec/metrics/relative_l1.py +++ b/tzrec/metrics/relative_l1.py @@ -29,9 +29,8 @@ class RelativeL1(Metric): def __init__(self, epsilon: float = 1e-4, **kwargs) -> None: super().__init__(**kwargs) self.epsilon = epsilon - # float64 sum / long count: element-wise aggregation crosses 2**24 at - # only ~32K rows of a 512-dim embedding, past which float32 increments - # round (mirrors the float64 care in ``ReservoirSampler.add``). + # float64 sum / long count: float32 loses integer precision past 2**24 + # (~32K rows of a 512-dim embedding) under element-wise aggregation. self.add_state( "sum_rel", default=torch.tensor(0.0, dtype=torch.float64), diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 579ab6702..d3023090c 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -97,9 +97,8 @@ def _extract_feature( feature_name = self._embedding_feature_name kt = batch.dense_features[BASE_DATA_GROUP] embedding = kt[feature_name] - # Guard a misconfigured feature width: a (B, 1) tensor (raw_feature - # missing value_dim, which defaults to 1) would otherwise broadcast - # silently downstream and fit a degenerate rank-1 codebook. + # Guard a misconfigured width: a (B, 1) tensor (raw_feature missing + # value_dim) would broadcast silently into a degenerate rank-1 codebook. if embedding.dim() != 2 or embedding.shape[1] != self._input_dim: raise ValueError( f"feature '{feature_name}' has shape {tuple(embedding.shape)}, " diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 317fa4779..59b05af41 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -58,14 +58,12 @@ def __init__( ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) - # CPU-only: training and inference both run on the host (embeddings, - # reservoir, FAISS fit, and post-fit assignment), so there are no device - # copies. v1 deliberately restricts the whole model to CPU; refuse to - # run when CUDA is visible. + # CPU-only: v1 restricts the whole model (train + inference) to the + # host. Refuse to run when CUDA is visible. if torch.cuda.is_available(): raise RuntimeError( "SidRqkmeans is CPU-only, but a CUDA device is visible. " - 'Run with CUDA_VISIBLE_DEVICES="" (or on a CPU-only host).' + 'Run with CUDA_VISIBLE_DEVICES="-1" (or on a CPU-only host).' ) # Single-process only: the fit runs over one process's local reservoir, @@ -177,19 +175,16 @@ def on_train_end(self) -> None: then persists the fitted codebook (SID runs with periodic checkpointing disabled, so that save is never deduped away). - TODO: the "periodic checkpointing disabled" requirement is currently a - convention, not enforced. If a user sets save_checkpoints_steps/epochs - > 0 and the last in-loop save lands on the final step, the tail save is - deduped away and the fitted codebook is silently dropped. Harden the - save logic (enforce the contract / bypass the dedupe for this save) in a - future update. + TODO: "periodic checkpointing disabled" is a convention, not enforced. + With save_checkpoints_steps/epochs > 0, a final-step in-loop save can + dedupe the tail save away, silently dropping the fitted codebook. Harden + this (enforce the contract / bypass the dedupe) in a future update. An empty reservoir only happens for a pathologically tiny corpus; the fit is then skipped. """ - # train_offline consumes its input; we hand it the reservoir buffer - # directly (no copy) since nothing reads it after this — reset() drops - # the sampler's reference and ``local`` is the last user of the storage. + # train_offline consumes its input; hand it the reservoir buffer + # directly (no copy) — nothing reads it after this. local = self._reservoir.sample() self._reservoir.reset() diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index ecc96db86..273d123a2 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -109,6 +109,27 @@ def test_init_raises_on_too_small_train_sample_size(self) -> None: with self.assertRaisesRegex(RuntimeError, "largest codebook"): self._create_model(codebook=[16, 16], train_sample_size=8) + def test_init_raises_on_empty_codebook(self) -> None: + """An empty codebook fails fast at construction.""" + with self.assertRaisesRegex(ValueError, "codebook must be set"): + self._create_model(codebook=[]) + + def test_init_raises_on_zero_codebook_entry(self) -> None: + """A zero codebook entry fails fast at construction.""" + with self.assertRaisesRegex(ValueError, "codebook entry must be >= 1"): + self._create_model(codebook=[16, 0]) + + def test_init_raises_on_zero_input_dim(self) -> None: + """input_dim < 1 fails fast at construction.""" + with self.assertRaisesRegex(ValueError, "input_dim must be >= 1"): + self._create_model(input_dim=0) + + def test_predict_raises_on_wrong_feature_width(self) -> None: + """A feature whose width != input_dim fails fast (missing value_dim).""" + model = self._create_model(input_dim=32) + with self.assertRaisesRegex(ValueError, "value_dim"): + model.predict(_batch_from_rows(torch.randn(8, 1))) + def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 4056ca861..11b06951c 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -173,25 +173,22 @@ def train_offline( rely on its contents afterward (copy first if it needs them). verbose (bool): print per-layer reconstruction loss. Default: True. """ - # Check the host-tensor contract locally (this is a standalone module) - # so misuse fails here, not deep inside faiss. raise (not assert): these - # guard silent data corruption and must survive ``python -O``. + # Host-tensor contract, checked here (not deep in faiss). raise (not + # assert): these guard data corruption and must survive ``python -O``. if inputs.is_cuda: raise RuntimeError("train_offline is CPU-only; got a CUDA tensor") if inputs.dim() != 2 or inputs.shape[1] != self.embed_dim: raise RuntimeError( f"inputs must be (N, {self.embed_dim}), got {tuple(inputs.shape)}" ) - # train_offline CONSUMES its input: the residual loop below mutates x - # in place (``x -= q``). We only normalize dtype/layout for faiss — a - # no-op view when the input is already float32 + contiguous, so the - # mutation lands in the caller's buffer (intended; the caller copies - # first if it still needs the data). + # The loop below mutates x in place (``x -= q``); the dtype/layout + # normalize is a no-op view when already float32 + contiguous, so the + # mutation lands in the caller's buffer (intended — see Args: CONSUMED). x = inputs.detach().to(dtype=torch.float32).contiguous() N = x.shape[0] - # Fail loudly on a too-small corpus: faiss.Kmeans only warns (not - # errors) when N < K and returns a degenerate codebook, which the - # all-zero poison guard in KMeansQuantizeLayer would not catch. + # Clear message before faiss's own opaque C++ throw for N < K. (The + # K <= N < K * min_points_per_centroid case, where faiss only warns and + # returns a degenerate codebook, is not guarded here.) max_k = max(self.n_embed_list) if N < max_k: raise RuntimeError( diff --git a/tzrec/modules/sid/residual_kmeans_quantizer_test.py b/tzrec/modules/sid/residual_kmeans_quantizer_test.py index 42647468e..265991143 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer_test.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer_test.py @@ -31,6 +31,18 @@ def test_non_uniform_codebook_supported(self) -> None: self.assertEqual(rkq.n_embed_list, [8, 4, 16]) self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) + def test_train_offline_raises_on_too_few_points(self) -> None: + """N < largest K fails fast (clear message before faiss's own throw).""" + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=1, n_embed=8) + with self.assertRaisesRegex(RuntimeError, "largest layer K"): + rkq.train_offline(torch.randn(4, 4), verbose=False) + + def test_train_offline_raises_on_wrong_dim(self) -> None: + """An input whose width != embed_dim fails fast.""" + rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=1, n_embed=8) + with self.assertRaisesRegex(RuntimeError, "inputs must be"): + rkq.train_offline(torch.randn(16, 8), verbose=False) + def test_forward_returns_zeros_before_fit(self) -> None: rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 94c5216e7..0c3595414 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -10,6 +10,8 @@ # limitations under the License. import glob +import json +import math import os import shutil import tempfile @@ -31,12 +33,10 @@ def setUp(self): os.makedirs("./tmp") self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") os.chmod(self.test_dir, 0o755) - # SID models are CPU-only (refuse a visible CUDA device) and - # single-process (refuse world_size > 1), so hide CUDA and pin - # nproc=1 — the GPU CI harness otherwise defaults to GPU + nproc=2. - # Use "-1", not "" — an empty CUDA_VISIBLE_DEVICES is treated - # inconsistently across CUDA runtimes (the GPU CI runner does not hide - # the devices), which trips the CPU-only guard in the train_eval child. + # SID is CPU-only + single-process, so hide CUDA and pin nproc=1 (the + # GPU CI harness defaults to GPU + nproc=2). Use "-1", not "" — an empty + # CUDA_VISIBLE_DEVICES is treated inconsistently and the GPU CI runner + # doesn't hide the devices, tripping the CPU-only guard in the child. patcher = mock.patch.dict( os.environ, {"CUDA_VISIBLE_DEVICES": "-1", "TEST_NPROC_PER_NODE": "1"} ) @@ -98,10 +98,22 @@ def test_sid_rqkmeans_train_eval(self): glob.glob(os.path.join(self.test_dir, "train", "model.ckpt-*")), "no checkpoint persisted after on_train_end", ) - self.assertTrue( - os.path.exists(os.path.join(self.test_dir, "train", "eval_result.txt")), - "no eval_result.txt produced", - ) + # A fitted codebook yields finite metrics; a degenerate/unfitted one + # never exposes x_hat -> metrics stay NaN. So assert finiteness, plus + # rel_loss < 1.0 (all-zero baseline ~ 1.0) and nonzero SID variety. + result_path = os.path.join(self.test_dir, "train", "eval_result.txt") + self.assertTrue(os.path.exists(result_path), "no eval_result.txt produced") + with open(result_path) as f: + lines = [ln for ln in f.read().splitlines() if ln.strip()] + self.assertTrue(lines, "eval_result.txt is empty") + metrics = json.loads(lines[-1]) + for key in ("mse", "rel_loss", "unique_sid_ratio"): + self.assertIn(key, metrics) + self.assertTrue( + math.isfinite(metrics[key]), f"{key} not finite: {metrics[key]}" + ) + self.assertLess(metrics["rel_loss"], 1.0) + self.assertGreater(metrics["unique_sid_ratio"], 0.0) if __name__ == "__main__": From 5f5af01400400312f02e1e063c6efd8dfd5f0efb Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 10 Jun 2026 12:27:44 +0000 Subject: [PATCH 079/129] [review] SID: drop _extract_feature width guard (embedding width is never 1) The (B, 1) broadcast footgun isn't reachable in practice, so revert _extract_feature to the plain feature read and remove its negative test. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 11 +---------- tzrec/models/sid_rqkmeans_test.py | 6 ------ 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index d3023090c..8db468799 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -96,16 +96,7 @@ def _extract_feature( if feature_name is None: feature_name = self._embedding_feature_name kt = batch.dense_features[BASE_DATA_GROUP] - embedding = kt[feature_name] - # Guard a misconfigured width: a (B, 1) tensor (raw_feature missing - # value_dim) would broadcast silently into a degenerate rank-1 codebook. - if embedding.dim() != 2 or embedding.shape[1] != self._input_dim: - raise ValueError( - f"feature '{feature_name}' has shape {tuple(embedding.shape)}, " - f"expected (B, {self._input_dim}); check that its value_dim " - "matches the SID input_dim." - ) - return embedding + return kt[feature_name] def init_loss(self) -> None: """Initialize loss modules. diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 273d123a2..0b68fefa6 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -124,12 +124,6 @@ def test_init_raises_on_zero_input_dim(self) -> None: with self.assertRaisesRegex(ValueError, "input_dim must be >= 1"): self._create_model(input_dim=0) - def test_predict_raises_on_wrong_feature_width(self) -> None: - """A feature whose width != input_dim fails fast (missing value_dim).""" - model = self._create_model(input_dim=32) - with self.assertRaisesRegex(ValueError, "value_dim"): - model.predict(_batch_from_rows(torch.randn(8, 1))) - def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 From 43e84cadb2a254439092c919677d3757611de65d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 01:58:04 +0000 Subject: [PATCH 080/129] [fix] SID integration test: skip on CUDA, run on CPU CI The end-to-end train_eval is CPU-only (SidRqkmeans refuses a visible CUDA device). Forcing CPU on the CUDA-built GPU CI image is unreliable (the prior CUDA_VISIBLE_DEVICES="" / "-1" workarounds both still failed in the train_eval child). Skip when CUDA is available so the test runs on the CPU CI job (where it passes) and skips on the GPU runner. Keep nproc=1 for the single-process guard. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/tests/sid_integration_test.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 0c3595414..53f24a1d3 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -21,6 +21,7 @@ import numpy as np import pyarrow as pa import pyarrow.parquet as pq +import torch from tzrec.tests import utils from tzrec.utils import config_util @@ -33,13 +34,9 @@ def setUp(self): os.makedirs("./tmp") self.test_dir = tempfile.mkdtemp(prefix="tzrec_", dir="./tmp") os.chmod(self.test_dir, 0o755) - # SID is CPU-only + single-process, so hide CUDA and pin nproc=1 (the - # GPU CI harness defaults to GPU + nproc=2). Use "-1", not "" — an empty - # CUDA_VISIBLE_DEVICES is treated inconsistently and the GPU CI runner - # doesn't hide the devices, tripping the CPU-only guard in the child. - patcher = mock.patch.dict( - os.environ, {"CUDA_VISIBLE_DEVICES": "-1", "TEST_NPROC_PER_NODE": "1"} - ) + # SidRqkmeans is single-process; pin nproc=1 (the CI harness defaults + # to 2, which would trip the world_size>1 guard). + patcher = mock.patch.dict(os.environ, {"TEST_NPROC_PER_NODE": "1"}) patcher.start() self.addCleanup(patcher.stop) @@ -73,6 +70,11 @@ def _prepare_config(self, num_rows: int, dim: int) -> str: config_util.save_message(config, config_path) return config_path + @unittest.skipIf( + torch.cuda.is_available(), + "SidRqkmeans is CPU-only; this end-to-end test runs on the CPU CI job. " + "Forcing CPU on a CUDA-built (GPU) image is unreliable.", + ) def test_sid_rqkmeans_train_eval(self): """End-to-end train -> on_train_end FAISS fit -> checkpoint -> eval. From 85b9d40f89b9977f89ba5f1d762943960f4b0cf4 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 07:10:48 +0000 Subject: [PATCH 081/129] [refactor] RQ-VAE: build on #539 QuantizeLayer ABC; retire old kmeans.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactors the RQ-VAE stack onto the reviewed #539 SID foundation so the VQ and K-Means backends share one per-layer interface, and keeps RQ-VAE device-agnostic (CPU + GPU) — unlike the deliberately CPU-only SidRqkmeans. - VectorQuantize now subclasses QuantizeLayer: implements quantize() -> QuantizeOutput and get_codebook_embeddings(); forward() delegates to quantize() so standalone vq(x) still works; the base lookup() returns the raw codebook vector embedding.weight[ids]. - ResidualVectorQuantizer drives the layer through the ABC (layer.quantize / layer.lookup / get_codebook_embeddings) instead of reaching into layer.embedding directly; behavior (raw-vector accumulation, STE-on-aggregate) is unchanged. - SidRqvae drops its update_metric override; #539's BaseSidModel now scores mse + rel_loss + unique_sid_ratio off predictions["x_hat"]/["codes"]. The train-path mse override stays (RQ-VAE has a train reconstruction). - Retire modules/sid/kmeans.py (replaced by #539's kmeans_quantize.py): relocate faiss_residual_kmeans into kmeans_quantize.py (CPU fit, centroids returned on the input device — safe from a GPU-resident RQ-VAE) and _squared_euclidean_distance into vector_quantize.py (its only user); drop the now-orphaned KMeansLayer / recon_diagnostics. Tests for the two moved helpers migrate to kmeans_quantize_test.py / vector_quantize_test.py. CPU + GPU: no hard-CUDA assumptions; the only device-sensitive path is the optional FAISS kmeans_init, which fits on CPU and moves centroids to the module's device (DDP: fit on rank 0, broadcast). Sinkhorn's all_reduce works under gloo and nccl. Verified on CPU: all SID unit tests pass (quantize_layer, vector_quantize, kmeans_quantize, residual_quantizer, residual_kmeans_quantizer, relative_l1, sid_rqvae, sid_rqkmeans, residual_vector_quantizer_dist). ruff check/format clean. GPU smoke + the full sid_integration_test must run in the torchgpuv4 container (this shell's CUDA driver is too old and has a stale installed tzrec; checkpoint_util import already fails there independent of this change). --- tzrec/models/sid_rqvae.py | 24 +- tzrec/modules/sid/kmeans.py | 222 ------------------ tzrec/modules/sid/kmeans_quantize.py | 63 ++++- tzrec/modules/sid/kmeans_quantize_test.py | 22 ++ tzrec/modules/sid/kmeans_test.py | 100 -------- .../modules/sid/residual_vector_quantizer.py | 14 +- tzrec/modules/sid/vector_quantize.py | 58 +++-- tzrec/modules/sid/vector_quantize_test.py | 17 +- 8 files changed, 158 insertions(+), 362 deletions(-) delete mode 100644 tzrec/modules/sid/kmeans.py delete mode 100644 tzrec/modules/sid/kmeans_test.py diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 644e61c74..b4681a9b9 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -336,29 +336,17 @@ def update_train_metric( ) -> None: """Update train metric state. - Args: - predictions (dict): a dict of predicted result. - batch (Batch): input batch data. - """ - if "x_hat" in predictions: - embedding = self._extract_feature(batch) - self._train_metric_modules["mse"].update(predictions["x_hat"], embedding) - - def update_metric( - self, - predictions: Dict[str, torch.Tensor], - batch: Batch, - losses: Optional[Dict[str, torch.Tensor]] = None, - ) -> None: - """Update metric state. + Overrides the BaseSidModel no-op: RQ-VAE has a train-time + reconstruction (the decoder output), so it can report a train-path mse. Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. - losses (dict, optional): a dict of loss. """ if "x_hat" in predictions: embedding = self._extract_feature(batch) - self._metric_modules["mse"].update(predictions["x_hat"], embedding) + self._train_metric_modules["mse"].update(predictions["x_hat"], embedding) - self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) + # Eval metrics (mse / rel_loss / unique_sid_ratio over predictions["x_hat"] + # and ["codes"]) are handled by BaseSidModel.update_metric — SidRqvae emits + # x_hat (the decoder reconstruction) so no override is needed here. diff --git a/tzrec/modules/sid/kmeans.py b/tzrec/modules/sid/kmeans.py deleted file mode 100644 index 0b6fe4255..000000000 --- a/tzrec/modules/sid/kmeans.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""K-Means utilities for the SID-generation stack. - -This module is the single home for torch-native K-Means code used by -SID models: - -* :class:`KMeansLayer` — per-layer centroid container used by - :class:`ResidualKMeansQuantizer`. Centroids are injected - by the FAISS backend via ``load_centroids_``; the only forward path - is ``predict``. -* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by - :class:`ResidualVectorQuantizer` to warm-start the RQ-VAE codebook on the - first training batch (same FAISS backend as the offline RQ-KMeans fit). -""" - -from typing import Dict, List, Optional, Tuple - -import torch -from torch import nn - - -def recon_diagnostics( - x: torch.Tensor, - out: torch.Tensor, - epsilon: float = 1e-4, -) -> Tuple[torch.Tensor, torch.Tensor]: - """MSE + relative-L1 reconstruction diagnostics. - - Shared by :meth:`SidRqkmeans.update_metric` (which wants tensors for - ``torchmetrics.MeanMetric``) and :meth:`ResidualKMeansQuantizer.train_offline`'s - per-layer log line (which converts to Python floats via ``.item()``). - - Args: - x: ground-truth embedding, shape (B, D). - out: quantized reconstruction, shape (B, D). - epsilon: numerical stabilizer for the relative-L1 denominator. - - Returns: - mse: scalar ``((out - x) ** 2).mean()``. - rel: scalar relative-L1 ``mean(|x - out| / (max(|x|, |out|) + eps))``. - """ - mse = ((out - x) ** 2).mean() - rel = ( - torch.abs(x - out) / (torch.maximum(torch.abs(x), torch.abs(out)) + epsilon) - ).mean() - return mse, rel - - -@torch.no_grad() -def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Squared L2 distance between rows of ``x`` and ``y``. - - Args: - x (Tensor): data points, shape (N, D). - y (Tensor): centroids, shape (K, D). - - Returns: - Tensor: squared distances, shape (N, K). - - Called per-batch from :meth:`KMeansLayer.predict`, so ``N`` is the batch - size and the full (N, K) product is small. Kept branch-free (no - data-dependent chunking on ``N``) so the predict forward stays - FX-traceable: torchrec's inference pipeline symbolically traces the - model, and a ``if N <= chunk_size`` on the traced batch dim raises a - ``torch.fx`` TraceError. - """ - x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) - y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) - return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) - - -@torch.no_grad() -def faiss_residual_kmeans( - samples: torch.Tensor, - n_clusters_list: List[int], - faiss_kmeans_kwargs: Optional[Dict] = None, -) -> List[torch.Tensor]: - """Residual K-Means warm-start via FAISS, one pass per layer. - - Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned - centroid, and repeats on the residual for every layer. Used by - :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook - from the first training batch — the same FAISS backend the offline - RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. - - Args: - samples (Tensor): data points, shape (N, D). - n_clusters_list (List[int]): per-layer cluster counts. - faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` - (e.g. ``{'niter': 10, 'seed': 123}``). - - Returns: - List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. - - Raises: - ImportError: if ``faiss`` is not installed. - """ - try: - import faiss - except ImportError as e: - raise ImportError( - "faiss is required for RQ-VAE kmeans_init. Install via " - "`pip install faiss-cpu` or `pip install faiss-gpu`." - ) from e - - kwargs = dict(faiss_kmeans_kwargs or {}) - device = samples.device - _, D = samples.shape - # Own a contiguous fp32 numpy copy we mutate in place to form residuals. - x = samples.detach().cpu().float().numpy().copy() - - res_centers: List[torch.Tensor] = [] - for n_clusters in n_clusters_list: - kmeans = faiss.Kmeans(D, n_clusters, **kwargs) - kmeans.train(x) - centroids = kmeans.centroids.copy() # (K, D) - res_centers.append(torch.from_numpy(centroids).to(device)) - _, idx = kmeans.index.search(x, 1) - x -= centroids[idx.ravel()] # residual, in place - return res_centers - - -class KMeansLayer(nn.Module): - """Single layer of a residual K-Means stack. - - Centroids are populated externally by ``load_centroids_`` (called per - layer by the FAISS backend in :class:`ResidualKMeansQuantizer`); ``predict`` - is the only forward path. PyTorch state-dict keys are scoped by - attribute path (``layers..centroids``), so renaming the class - does not break existing checkpoints. - - Args: - n_clusters (int): number of clusters (codebook size). - n_features (int): feature dimension. - """ - - def __init__( - self, - n_clusters: int, - n_features: int, - ) -> None: - super().__init__() - self.n_clusters = n_clusters - self.n_features = n_features - - self.register_buffer("centroids", torch.zeros(n_clusters, n_features)) - # Flipped by ``load_centroids_`` after the FAISS fit. Persistent - # so a normal post-fit checkpoint round-trips; mid-fit poisoning - # (True flag + still-zero centroids) is caught in _load_from_state_dict. - self.register_buffer("_is_initialized", torch.tensor(False)) - - @property - def is_initialized(self) -> bool: - """Whether centroids have been injected via ``load_centroids_``.""" - return self._is_initialized.item() - - @torch.no_grad() - def load_centroids_(self, centroids: torch.Tensor) -> None: - """Inject offline-trained centroids. - - Args: - centroids (Tensor): externally trained centroids, - shape (n_clusters, n_features). - """ - assert centroids.shape == self.centroids.shape, ( - f"centroids shape mismatch: expected {tuple(self.centroids.shape)}, " - f"got {tuple(centroids.shape)}" - ) - self.centroids.copy_( - centroids.to(dtype=self.centroids.dtype, device=self.centroids.device) - ) - self._is_initialized.fill_(True) - - def _load_from_state_dict( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) -> None: - """Reject mid-fit-checkpoint state dicts (True flag + zero centroids).""" - super()._load_from_state_dict( - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ) - if bool(self._is_initialized.item()) and self.centroids.abs().sum() == 0: - error_msgs.append( - f"KMeansLayer at '{prefix}': _is_initialized=True but centroids " - "are all zero — checkpoint was likely taken mid-FAISS-fit. " - "Re-run on_train_end to produce a valid checkpoint." - ) - - @torch.no_grad() - def predict(self, batch: torch.Tensor) -> torch.Tensor: - """Assign points to nearest centroid. - - Args: - batch (Tensor): data points, shape (B, D). - - Returns: - Tensor: cluster indices, shape (B,). - """ - dists = _squared_euclidean_distance(batch, self.centroids) - return torch.argmin(dists, dim=-1) diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index 6eb5b940a..c24778802 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -19,9 +19,14 @@ * :class:`ReservoirSampler` — bounded uniform stream sample (Vitter Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` fills during training to feed the one-shot FAISS fit. +* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by + :class:`~tzrec.modules.sid.residual_vector_quantizer.ResidualVectorQuantizer` + to warm-start the RQ-VAE codebook on the first training batch (same FAISS + backend as the offline RQ-KMeans fit). Fits on CPU and returns centroids on + the input device, so it is safe to call from a GPU-resident RQ-VAE. """ -from typing import Optional +from typing import Dict, List, Optional import torch @@ -30,6 +35,62 @@ from tzrec.utils.logging_util import logger +@torch.no_grad() +def faiss_residual_kmeans( + samples: torch.Tensor, + n_clusters_list: List[int], + faiss_kmeans_kwargs: Optional[Dict] = None, +) -> List[torch.Tensor]: + """Residual K-Means warm-start via FAISS, one pass per layer. + + Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned + centroid, and repeats on the residual for every layer. Used by + :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook + from the first training batch — the same FAISS backend the offline + RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. + + Device handling (CPU + GPU): the FAISS fit is always CPU (``samples`` is + copied to host as fp32 numpy), and the returned centroids are moved back to + ``samples.device``. So an RQ-VAE training on GPU gets GPU centroids while + the fit itself stays on CPU — no faiss-gpu build required. + + Args: + samples (Tensor): data points, shape (N, D). + n_clusters_list (List[int]): per-layer cluster counts. + faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` + (e.g. ``{'niter': 10, 'seed': 123}``). + + Returns: + List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. + + Raises: + ImportError: if ``faiss`` is not installed. + """ + try: + import faiss + except ImportError as e: + raise ImportError( + "faiss is required for RQ-VAE kmeans_init. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + kwargs = dict(faiss_kmeans_kwargs or {}) + device = samples.device + _, D = samples.shape + # Own a contiguous fp32 numpy copy we mutate in place to form residuals. + x = samples.detach().cpu().float().numpy().copy() + + res_centers: List[torch.Tensor] = [] + for n_clusters in n_clusters_list: + kmeans = faiss.Kmeans(D, n_clusters, **kwargs) + kmeans.train(x) + centroids = kmeans.centroids.copy() # (K, D) + res_centers.append(torch.from_numpy(centroids).to(device)) + _, idx = kmeans.index.search(x, 1) + x -= centroids[idx.ravel()] # residual, in place + return res_centers + + class ReservoirSampler: """Bounded uniform sample of a stream (Vitter Algorithm R). diff --git a/tzrec/modules/sid/kmeans_quantize_test.py b/tzrec/modules/sid/kmeans_quantize_test.py index 2f2883562..903765ef9 100644 --- a/tzrec/modules/sid/kmeans_quantize_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -16,9 +16,31 @@ from tzrec.modules.sid.kmeans_quantize import ( KMeansQuantizeLayer, ReservoirSampler, + faiss_residual_kmeans, ) +class FaissResidualKmeansTest(unittest.TestCase): + """Tests for the FAISS residual K-Means warm-start helper.""" + + def test_faiss_residual_kmeans_per_layer_centers(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + samples = torch.randn(512, 6) + centers = faiss_residual_kmeans( + samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} + ) + self.assertEqual(len(centers), 2) + self.assertEqual(centers[0].shape, (8, 6)) + self.assertEqual(centers[1].shape, (4, 6)) + self.assertTrue(torch.isfinite(centers[0]).all()) + # Centroids come back on the input device (CPU fit, device-preserving). + self.assertEqual(centers[0].device, samples.device) + + class KMeansQuantizeLayerTest(unittest.TestCase): """Tests for the single KMeansQuantizeLayer.""" diff --git a/tzrec/modules/sid/kmeans_test.py b/tzrec/modules/sid/kmeans_test.py deleted file mode 100644 index 8fed1f83a..000000000 --- a/tzrec/modules/sid/kmeans_test.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch - -from tzrec.modules.sid.kmeans import ( - KMeansLayer, - _squared_euclidean_distance, - faiss_residual_kmeans, - recon_diagnostics, -) - - -class KmeansHelpersTest(unittest.TestCase): - """Tests for the K-Means helper functions.""" - - def test_recon_diagnostics_zero_on_identity(self) -> None: - x = torch.randn(8, 4) - mse, rel = recon_diagnostics(x, x.clone()) - self.assertAlmostEqual(mse.item(), 0.0, places=6) - self.assertAlmostEqual(rel.item(), 0.0, places=6) - - def test_squared_euclidean_distance(self) -> None: - x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) - y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) - d = _squared_euclidean_distance(x, y) - self.assertEqual(d.shape, (2, 2)) - # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 - torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) - - def test_faiss_residual_kmeans_per_layer_centers(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - samples = torch.randn(512, 6) - centers = faiss_residual_kmeans( - samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} - ) - self.assertEqual(len(centers), 2) - self.assertEqual(centers[0].shape, (8, 6)) - self.assertEqual(centers[1].shape, (4, 6)) - self.assertTrue(torch.isfinite(centers[0]).all()) - self.assertEqual(centers[0].device, samples.device) - - -class KMeansLayerTest(unittest.TestCase): - """Tests for the single KMeansLayer.""" - - def test_uninitialized_by_default(self) -> None: - layer = KMeansLayer(n_clusters=4, n_features=3) - self.assertFalse(layer.is_initialized) - self.assertEqual(layer.centroids.abs().sum().item(), 0.0) - - def test_load_centroids_and_predict(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) - centroids = torch.tensor([[0.0, 0.0], [10.0, 10.0]]) - layer.load_centroids_(centroids) - self.assertTrue(layer.is_initialized) - - batch = torch.tensor([[0.1, 0.0], [9.0, 11.0]]) - codes = layer.predict(batch) - torch.testing.assert_close(codes, torch.tensor([0, 1])) - - def test_load_centroids_shape_mismatch_raises(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) - with self.assertRaises(AssertionError): - layer.load_centroids_(torch.zeros(3, 2)) - - def test_mid_fit_checkpoint_rejected(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) - sd = layer.state_dict() - # Simulate a mid-fit checkpoint: flag True but centroids still zero. - sd["_is_initialized"] = torch.tensor(True) - fresh = KMeansLayer(n_clusters=2, n_features=2) - with self.assertRaisesRegex(RuntimeError, "mid-FAISS-fit"): - fresh.load_state_dict(sd) - - def test_post_fit_checkpoint_round_trips(self) -> None: - layer = KMeansLayer(n_clusters=2, n_features=2) - layer.load_centroids_(torch.tensor([[1.0, 2.0], [3.0, 4.0]])) - fresh = KMeansLayer(n_clusters=2, n_features=2) - fresh.load_state_dict(layer.state_dict()) - self.assertTrue(fresh.is_initialized) - torch.testing.assert_close(fresh.centroids, layer.centroids) - - -if __name__ == "__main__": - unittest.main() diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 20f534a36..f00c6319b 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -18,7 +18,7 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans import faiss_residual_kmeans +from tzrec.modules.sid.kmeans_quantize import faiss_residual_kmeans from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.modules.sid.types import ( QuantizeForwardMode, @@ -297,8 +297,12 @@ def _quantize_layer( raw_emb (Tensor): raw codebook vectors (with grad), shape (B, D). """ layer = self.layers[layer_idx] - out = layer(residual, temperature=temperature) - return out.ids, layer.embedding(out.ids) + out = layer.quantize(residual, temperature) + # ``lookup`` (QuantizeLayer base) returns the raw codebook vector + # embedding.weight[ids] — gradient still flows into the codebook; STE is + # applied once on the aggregate in :meth:`forward`. The soft STE/Gumbel + # ``out.embeddings`` is intentionally not used in the residual walk. + return out.ids, layer.lookup(out.ids) def forward( self, @@ -363,8 +367,8 @@ def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: Returns: Tensor: codebook weights, shape (n_embed, embed_dim). """ - return self.layers[layer_idx].embedding.weight.data + return self.layers[layer_idx].get_codebook_embeddings() def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: """Look up codebook vectors via the layer's embedding table.""" - return self.layers[layer_idx].embedding(code_idx) + return self.layers[layer_idx].lookup(code_idx) diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 16ec0d629..b8f7fc0d8 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -18,13 +18,33 @@ from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans import _squared_euclidean_distance +from tzrec.modules.sid.quantize_layer import QuantizeLayer from tzrec.modules.sid.types import ( QuantizeForwardMode, QuantizeOutput, ) +@torch.no_grad() +def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Squared L2 distance between rows of ``x`` and ``y``. + + Args: + x (Tensor): data points, shape (N, D). + y (Tensor): centroids, shape (K, D). + + Returns: + Tensor: squared distances, shape (N, K). + + Device-agnostic (pure torch); the (N, K) product is small (N is the batch + size). Kept branch-free (no data-dependent chunking on ``N``) so the + forward stays FX-traceable for torchrec's inference pipeline. + """ + x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) + y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) + return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) + + def _gumbel_softmax_sample( logits: torch.Tensor, temperature: float = 1.0, @@ -104,11 +124,14 @@ def _sinkhorn( return Q.t() # (B, K) -class VectorQuantize(nn.Module): - """Single codebook vector quantization layer. +class VectorQuantize(QuantizeLayer): + """Single codebook vector quantization layer (RQ-VAE backend). - Maps continuous input vectors to the nearest codebook entry and returns - the quantized embeddings + codebook indices. The commitment loss is + The VQ :class:`~tzrec.modules.sid.quantize_layer.QuantizeLayer`: a + gradient-trained ``nn.Embedding`` codebook, the sibling of the K-Means + backend's :class:`~tzrec.modules.sid.kmeans_quantize.KMeansQuantizeLayer`. + Maps continuous input vectors to a codebook entry and returns the quantized + embeddings + codebook indices via :meth:`quantize`. The commitment loss is computed at the residual-aggregator level by :meth:`ResidualVectorQuantizer._single_commitment_loss` over the cumulative quants (matching al_sid's ``RQBottleneck.compute_commitment_loss``); @@ -141,7 +164,7 @@ def __init__( sinkhorn_iters: int = 5, sinkhorn_epsilon: float = 10.0, ) -> None: - super().__init__() + super().__init__(n_embed=n_embed, embed_dim=embed_dim) # Sinkhorn + Gumbel-Softmax pick the code by two different rules: # `ids` come from the Sinkhorn balanced-assignment argmax, while the # Gumbel branch builds `emb` from argmax(-distances + noise) (nearest @@ -155,8 +178,7 @@ def __init__( "`emb` (nearest code), so the returned id and embedding diverge. " "Use STE with Sinkhorn, or Gumbel-Softmax without Sinkhorn." ) - self.embed_dim = embed_dim - self.n_embed = n_embed + # ``n_embed`` / ``embed_dim`` are owned by the QuantizeLayer base. self.forward_mode = forward_mode self.distance_type = distance_type self.use_sinkhorn = use_sinkhorn @@ -231,12 +253,8 @@ def _find_nearest_embedding( return ids, distances - def forward( - self, - x: torch.Tensor, - temperature: float = 1.0, - ) -> QuantizeOutput: - """Forward the vector quantization layer. + def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Assign ``x`` to the codebook (the :class:`QuantizeLayer` interface). Training flow: 1. compute distances (L2 or cosine) @@ -245,7 +263,9 @@ def forward( 3. compute differentiable embedding (STE or Gumbel-Softmax) Commitment loss is computed by the caller - (:meth:`ResidualVectorQuantizer._single_commitment_loss`). + (:meth:`ResidualVectorQuantizer._single_commitment_loss`). Device follows + ``x`` (and the codebook, which moves with the module), so this runs on + CPU or GPU unchanged. Args: x (Tensor): input vectors, shape (B, D). @@ -275,3 +295,11 @@ def forward( emb = self.embedding(ids) return QuantizeOutput(embeddings=emb, ids=ids) + + def forward(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + """Delegate to :meth:`quantize` so standalone ``vq(x)`` still works.""" + return self.quantize(x, temperature) + + def get_codebook_embeddings(self) -> torch.Tensor: + """Return the codebook table, shape (n_embed, embed_dim).""" + return self.embedding.weight diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index 4df9208dc..bc2de184d 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -15,7 +15,22 @@ from parameterized import parameterized from tzrec.modules.sid.types import QuantizeForwardMode -from tzrec.modules.sid.vector_quantize import VectorQuantize +from tzrec.modules.sid.vector_quantize import ( + VectorQuantize, + _squared_euclidean_distance, +) + + +class SquaredEuclideanDistanceTest(unittest.TestCase): + """Tests for the squared-L2 distance helper used by VectorQuantize.""" + + def test_squared_euclidean_distance(self) -> None: + x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) + d = _squared_euclidean_distance(x, y) + self.assertEqual(d.shape, (2, 2)) + # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 + torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) class VectorQuantizeTest(unittest.TestCase): From c838aeca2b97b3455165ea87a70f63c9a2e7f7e0 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 07:23:47 +0000 Subject: [PATCH 082/129] [simplify] SID: drop VectorQuantize.forward delegator; trim redundant comments Quality pass on the RQ-VAE refactor (no behavior change): - VectorQuantize: drop the forward() delegator. quantize() is the QuantizeLayer entry point (matching the KMeansQuantizeLayer sibling, which has no forward); the residual walk already calls layer.quantize(). The 3 test sites move from vq(x) to vq.quantize(x). - ResidualVectorQuantizer._quantize_layer: trim the 4-line inline comment that duplicated the docstring (and carried a stale embedding.weight[ids] note) to a one-liner pointing at the docstring. - SidRqvae: fold the floating end-of-class "no update_metric override" comment into update_train_metric's docstring. ruff clean; vector_quantize / residual_vector_quantizer_dist / sid_rqvae / residual_quantizer tests pass. --- tzrec/models/sid_rqvae.py | 8 ++++---- tzrec/modules/sid/residual_vector_quantizer.py | 6 ++---- tzrec/modules/sid/vector_quantize.py | 4 ---- tzrec/modules/sid/vector_quantize_test.py | 6 +++--- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index b4681a9b9..da7dd04f5 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -338,6 +338,10 @@ def update_train_metric( Overrides the BaseSidModel no-op: RQ-VAE has a train-time reconstruction (the decoder output), so it can report a train-path mse. + The eval metrics (mse / rel_loss / unique_sid_ratio over + ``predictions["x_hat"]`` and ``["codes"]``) are handled by + ``BaseSidModel.update_metric`` — SidRqvae emits ``x_hat``, so it needs + no ``update_metric`` override. Args: predictions (dict): a dict of predicted result. @@ -346,7 +350,3 @@ def update_train_metric( if "x_hat" in predictions: embedding = self._extract_feature(batch) self._train_metric_modules["mse"].update(predictions["x_hat"], embedding) - - # Eval metrics (mse / rel_loss / unique_sid_ratio over predictions["x_hat"] - # and ["codes"]) are handled by BaseSidModel.update_metric — SidRqvae emits - # x_hat (the decoder reconstruction) so no override is needed here. diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index f00c6319b..20b6a0230 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -298,10 +298,8 @@ def _quantize_layer( """ layer = self.layers[layer_idx] out = layer.quantize(residual, temperature) - # ``lookup`` (QuantizeLayer base) returns the raw codebook vector - # embedding.weight[ids] — gradient still flows into the codebook; STE is - # applied once on the aggregate in :meth:`forward`. The soft STE/Gumbel - # ``out.embeddings`` is intentionally not used in the residual walk. + # Re-look up the raw codebook vector (not the soft STE/Gumbel + # ``out.embeddings``); see the docstring for why. return out.ids, layer.lookup(out.ids) def forward( diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index b8f7fc0d8..a3becf51f 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -296,10 +296,6 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: return QuantizeOutput(embeddings=emb, ids=ids) - def forward(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: - """Delegate to :meth:`quantize` so standalone ``vq(x)`` still works.""" - return self.quantize(x, temperature) - def get_codebook_embeddings(self) -> torch.Tensor: """Return the codebook table, shape (n_embed, embed_dim).""" return self.embedding.weight diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index bc2de184d..d7a7e3fc8 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -56,7 +56,7 @@ def test_train_forward(self, _name, mode, distance_type, use_sinkhorn) -> None: ) vq.train() x = torch.randn(5, 8, requires_grad=True) - out = vq(x) + out = vq.quantize(x) self.assertEqual(out.embeddings.shape, (5, 8)) self.assertEqual(out.ids.shape, (5,)) self.assertTrue((out.ids >= 0).all() and (out.ids < 16).all()) @@ -77,7 +77,7 @@ def test_train_forward_backward_reaches_input(self) -> None: vq = VectorQuantize(embed_dim=8, n_embed=16, use_sinkhorn=False) vq.train() x = torch.randn(5, 8, requires_grad=True) - out = vq(x) + out = vq.quantize(x) out.embeddings.sum().backward() # STE routes gradient back through x. self.assertIsNotNone(x.grad) @@ -88,7 +88,7 @@ def test_eval_forward_is_plain_lookup(self) -> None: vq = VectorQuantize(embed_dim=4, n_embed=8) vq.eval() x = torch.randn(3, 4) - out = vq(x) + out = vq.quantize(x) # In eval, emb == embedding(ids) exactly. torch.testing.assert_close(out.embeddings, vq.embedding(out.ids)) From a9b8d188ba44b7707806b5009d4f9dbb7d260aba Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 07:29:07 +0000 Subject: [PATCH 083/129] [simplify] SID: move faiss_residual_kmeans to its only user (RVQ module) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit faiss_residual_kmeans is consumed solely by ResidualVectorQuantizer's RQ-VAE warm-start; the K-Means backend (ResidualKMeansQuantizer.train_offline) has its own inline FAISS loop and never calls it. Living in kmeans_quantize.py made the VQ backend import across into the K-Means backend's module, and it didn't even fit that module's stated charter ("torch-native K-Means code" — this is FAISS). Move it to residual_vector_quantizer.py, beside its caller — mirroring _squared_euclidean_distance living in vector_quantize.py (its only user). Drops the cross-backend import and the now-unused Dict/List imports in kmeans_quantize.py. The helper's test moves to a new residual_vector_quantizer_test.py. ruff clean; kmeans_quantize / residual_vector_quantizer(+dist) / sid_rqvae tests pass. --- tzrec/modules/sid/kmeans_quantize.py | 63 +------------------ tzrec/modules/sid/kmeans_quantize_test.py | 22 ------- .../modules/sid/residual_vector_quantizer.py | 59 ++++++++++++++++- .../sid/residual_vector_quantizer_test.py | 41 ++++++++++++ 4 files changed, 99 insertions(+), 86 deletions(-) create mode 100644 tzrec/modules/sid/residual_vector_quantizer_test.py diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index c24778802..6eb5b940a 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -19,14 +19,9 @@ * :class:`ReservoirSampler` — bounded uniform stream sample (Vitter Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` fills during training to feed the one-shot FAISS fit. -* :func:`faiss_residual_kmeans` — FAISS residual K-Means used by - :class:`~tzrec.modules.sid.residual_vector_quantizer.ResidualVectorQuantizer` - to warm-start the RQ-VAE codebook on the first training batch (same FAISS - backend as the offline RQ-KMeans fit). Fits on CPU and returns centroids on - the input device, so it is safe to call from a GPU-resident RQ-VAE. """ -from typing import Dict, List, Optional +from typing import Optional import torch @@ -35,62 +30,6 @@ from tzrec.utils.logging_util import logger -@torch.no_grad() -def faiss_residual_kmeans( - samples: torch.Tensor, - n_clusters_list: List[int], - faiss_kmeans_kwargs: Optional[Dict] = None, -) -> List[torch.Tensor]: - """Residual K-Means warm-start via FAISS, one pass per layer. - - Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned - centroid, and repeats on the residual for every layer. Used by - :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook - from the first training batch — the same FAISS backend the offline - RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. - - Device handling (CPU + GPU): the FAISS fit is always CPU (``samples`` is - copied to host as fp32 numpy), and the returned centroids are moved back to - ``samples.device``. So an RQ-VAE training on GPU gets GPU centroids while - the fit itself stays on CPU — no faiss-gpu build required. - - Args: - samples (Tensor): data points, shape (N, D). - n_clusters_list (List[int]): per-layer cluster counts. - faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` - (e.g. ``{'niter': 10, 'seed': 123}``). - - Returns: - List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. - - Raises: - ImportError: if ``faiss`` is not installed. - """ - try: - import faiss - except ImportError as e: - raise ImportError( - "faiss is required for RQ-VAE kmeans_init. Install via " - "`pip install faiss-cpu` or `pip install faiss-gpu`." - ) from e - - kwargs = dict(faiss_kmeans_kwargs or {}) - device = samples.device - _, D = samples.shape - # Own a contiguous fp32 numpy copy we mutate in place to form residuals. - x = samples.detach().cpu().float().numpy().copy() - - res_centers: List[torch.Tensor] = [] - for n_clusters in n_clusters_list: - kmeans = faiss.Kmeans(D, n_clusters, **kwargs) - kmeans.train(x) - centroids = kmeans.centroids.copy() # (K, D) - res_centers.append(torch.from_numpy(centroids).to(device)) - _, idx = kmeans.index.search(x, 1) - x -= centroids[idx.ravel()] # residual, in place - return res_centers - - class ReservoirSampler: """Bounded uniform sample of a stream (Vitter Algorithm R). diff --git a/tzrec/modules/sid/kmeans_quantize_test.py b/tzrec/modules/sid/kmeans_quantize_test.py index 903765ef9..2f2883562 100644 --- a/tzrec/modules/sid/kmeans_quantize_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -16,31 +16,9 @@ from tzrec.modules.sid.kmeans_quantize import ( KMeansQuantizeLayer, ReservoirSampler, - faiss_residual_kmeans, ) -class FaissResidualKmeansTest(unittest.TestCase): - """Tests for the FAISS residual K-Means warm-start helper.""" - - def test_faiss_residual_kmeans_per_layer_centers(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - samples = torch.randn(512, 6) - centers = faiss_residual_kmeans( - samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} - ) - self.assertEqual(len(centers), 2) - self.assertEqual(centers[0].shape, (8, 6)) - self.assertEqual(centers[1].shape, (4, 6)) - self.assertTrue(torch.isfinite(centers[0]).all()) - # Centroids come back on the input device (CPU fit, device-preserving). - self.assertEqual(centers[0].device, samples.device) - - class KMeansQuantizeLayerTest(unittest.TestCase): """Tests for the single KMeansQuantizeLayer.""" diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 20b6a0230..1e9073d1e 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -11,14 +11,13 @@ """ResidualVectorQuantizer: multi-layer residual VQ with gradient training.""" -from typing import List, Sequence, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import torch import torch.distributed as dist from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans_quantize import faiss_residual_kmeans from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.modules.sid.types import ( QuantizeForwardMode, @@ -28,6 +27,62 @@ from tzrec.utils.logging_util import logger +@torch.no_grad() +def faiss_residual_kmeans( + samples: torch.Tensor, + n_clusters_list: List[int], + faiss_kmeans_kwargs: Optional[Dict] = None, +) -> List[torch.Tensor]: + """Residual K-Means warm-start via FAISS, one pass per layer. + + Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned + centroid, and repeats on the residual for every layer. Used by + :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook + from the first training batch — the same FAISS backend the offline + RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. + + Device handling (CPU + GPU): the FAISS fit is always CPU (``samples`` is + copied to host as fp32 numpy), and the returned centroids are moved back to + ``samples.device``. So an RQ-VAE training on GPU gets GPU centroids while + the fit itself stays on CPU — no faiss-gpu build required. + + Args: + samples (Tensor): data points, shape (N, D). + n_clusters_list (List[int]): per-layer cluster counts. + faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans`` + (e.g. ``{'niter': 10, 'seed': 123}``). + + Returns: + List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device. + + Raises: + ImportError: if ``faiss`` is not installed. + """ + try: + import faiss + except ImportError as e: + raise ImportError( + "faiss is required for RQ-VAE kmeans_init. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + kwargs = dict(faiss_kmeans_kwargs or {}) + device = samples.device + _, D = samples.shape + # Own a contiguous fp32 numpy copy we mutate in place to form residuals. + x = samples.detach().cpu().float().numpy().copy() + + res_centers: List[torch.Tensor] = [] + for n_clusters in n_clusters_list: + kmeans = faiss.Kmeans(D, n_clusters, **kwargs) + kmeans.train(x) + centroids = kmeans.centroids.copy() # (K, D) + res_centers.append(torch.from_numpy(centroids).to(device)) + _, idx = kmeans.index.search(x, 1) + x -= centroids[idx.ravel()] # residual, in place + return res_centers + + class ResidualVectorQuantizer(ResidualQuantizer): """Multi-layer residual vector quantization. diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py new file mode 100644 index 000000000..3e7034f81 --- /dev/null +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -0,0 +1,41 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from tzrec.modules.sid.residual_vector_quantizer import faiss_residual_kmeans + + +class FaissResidualKmeansTest(unittest.TestCase): + """Tests for the FAISS residual K-Means warm-start helper.""" + + def test_faiss_residual_kmeans_per_layer_centers(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + samples = torch.randn(512, 6) + centers = faiss_residual_kmeans( + samples, [8, 4], {"niter": 5, "verbose": False, "seed": 1} + ) + self.assertEqual(len(centers), 2) + self.assertEqual(centers[0].shape, (8, 6)) + self.assertEqual(centers[1].shape, (4, 6)) + self.assertTrue(torch.isfinite(centers[0]).all()) + # Centroids come back on the input device (CPU fit, device-preserving). + self.assertEqual(centers[0].device, samples.device) + + +if __name__ == "__main__": + unittest.main() From 9194ee054fbf7f14a162413bc97db4f4a1a58184 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 07:37:30 +0000 Subject: [PATCH 084/129] [fix] RVQ.get_codebook_embeddings: detach the read-only accessor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The QuantizeLayer-ABC delegation changed RVQ.get_codebook_embeddings(layer_idx) from `.embedding.weight.data` (a detached leaf) to the layer's get_codebook_embeddings() == embedding.weight (a grad-requiring nn.Parameter). The @torch.no_grad() decorator does NOT detach a directly-returned leaf, so callers of this read-only export/inspection accessor now get requires_grad=True — diverging from both the prior behavior and the K-Means sibling (whose codebook is a buffer, requires_grad=False). An external consumer doing .numpy() would raise, or an in-place write would corrupt the live codebook. Detach in the accessor. The layer-level get_codebook_embeddings() stays grad-carrying (the training lookup path needs it); only this no_grad RVQ-level inspection method detaches. Codebook training is unaffected. No in-repo caller currently hits this (only tests + the grad-fine lookup path), so it's a latent public-API fix, not an active bug. --- tzrec/modules/sid/residual_vector_quantizer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 1e9073d1e..f86b2ca6e 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -414,13 +414,19 @@ def forward( def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """Get codebook embedding weights for a specific layer. + Detached: the layer's ``get_codebook_embeddings`` returns the live + ``nn.Embedding.weight`` (a grad-requiring leaf, needed by the training + ``lookup`` path), but this is a read-only accessor for export/inspection, + so it returns a non-grad view — matching the K-Means sibling, whose + codebook is a buffer. + Args: layer_idx (int): index of the quantization layer. Returns: Tensor: codebook weights, shape (n_embed, embed_dim). """ - return self.layers[layer_idx].get_codebook_embeddings() + return self.layers[layer_idx].get_codebook_embeddings().detach() def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor: """Look up codebook vectors via the layer's embedding table.""" From 2284e86d267eacacad57f44387a99636cafd884c Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 09:55:26 +0000 Subject: [PATCH 085/129] [fix] SID code-review: CLIP empty-mask NaN, latent_weight validation, kmeans_init default Three fixes from the full RQ-VAE code review (findings #5, #3, #2): - CLIP empty-clip-mask (#5): a batch/rank with zero clip rows had every logit column masked to float("-inf"), so cross_entropy produced NaN and a NaN gradient (0 * NaN) that survived the row mask and poisoned the encoder. Replace the -inf column fill with a finite large-negative sentinel (_MASKED_LOGIT_FILL = -1e4): same masking (exp == 0) but finite CE/grad, so a no-clip batch contributes exactly 0 with 0 gradient. Branch-free (keeps the torch.compile-friendly design); safe under fp16/bf16 autocast (within fp16 range). nan_to_num kept as a backstop for non-finite upstream logits. - latent_weight length (#3): the quantizer unpacks `w1, w2 = latent_weight`, so a config with len != 2 raised a cryptic unpack ValueError deep in ResidualVectorQuantizer. Validate in SidRqvae.__init__ with a field-named error. - kmeans_init default (#2): flip the proto default true -> false. The FAISS warm-start needs the first batch to have >= max(codebook) rows and, under DDP, a rank-0 fit failure hangs the other ranks on the centroid broadcast. Making it opt-in removes that footgun from the default path. Tests: add a finite-gradient regression for the all-recon CLIP batch and a latent_weight wrong-length test. ruff clean; clip_loss / sid_rqvae / sid_rqkmeans suites pass. --- tzrec/loss/clip_loss.py | 32 +++++++++++++++++++++-------- tzrec/loss/clip_loss_test.py | 13 ++++++++++++ tzrec/models/sid_rqvae.py | 10 ++++++++- tzrec/models/sid_rqvae_test.py | 14 +++++++++++++ tzrec/protos/models/sid_model.proto | 7 +++++-- 5 files changed, 65 insertions(+), 11 deletions(-) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py index f6d1097d5..7fd1092a6 100644 --- a/tzrec/loss/clip_loss.py +++ b/tzrec/loss/clip_loss.py @@ -19,6 +19,15 @@ from torch.nn import functional as F from torch.nn.modules.loss import _Loss +# Finite large-negative fill for masked-out (recon) logit columns. ``exp()`` of +# it underflows to 0 (same masking effect as ``-inf``), but unlike ``-inf`` it +# keeps cross-entropy and its gradient finite when a row has no valid column +# (a batch/rank with zero clip rows). With ``-inf`` such a row yields NaN and, +# crucially, a NaN gradient (``0 * NaN``) that survives the row mask and poisons +# the encoder. Finite, so a no-clip batch contributes exactly 0 with 0 gradient. +# Kept well within fp16 range (-65504) so it is safe under FP16/BF16 autocast. +_MASKED_LOGIT_FILL = -1e4 + class MaskedCLIPLoss(_Loss): """Masked CLIP loss for mixed recon+clip batches. @@ -102,10 +111,15 @@ def _masked_cross_entropy( """ ce_i = F.cross_entropy(logits_i, safe_labels, reduction="none") ce_t = F.cross_entropy(logits_t, safe_labels, reduction="none") - # NaN can occur when all logits are -inf (all-recon edge case) + # Backstop only: the finite _MASKED_LOGIT_FILL already keeps the + # all-recon row finite, so this guards solely against a non-finite + # logit arriving from upstream (e.g. an overflowed logit_scale). ce_i = torch.nan_to_num(ce_i, nan=0.0) ce_t = torch.nan_to_num(ce_t, nan=0.0) + # Row mask: only clip rows contribute; clamp(min=1) keeps a no-clip + # batch at 0 (not 0/0). Combined with the finite fill, a batch with no + # clip rows yields exactly 0 loss and 0 gradient. n_valid = clip_mask.float().sum().clamp(min=1) return ((ce_i + ce_t) * clip_mask.float()).sum() / (2 * n_valid) @@ -159,16 +173,18 @@ def forward( logits_img_cl = logit_scale_cl * image_embed @ image_embed_all_ori.t() logits_txt_cl = logit_scale_cl * text_embed @ text_embed_all_ori.t() - # --- Column mask: recon columns -> -inf (not as negatives) --- + # --- Column mask: recon columns -> large-negative (not as negatives) --- + # Finite fill (not -inf) so an all-recon row keeps a finite, non-NaN + # gradient; see _MASKED_LOGIT_FILL. clip_mask_all = self._gather_bool_mask(clip_mask) col_mask = (~clip_mask_all).unsqueeze(0) # (1, B_global) - logits_img_self = logits_img_self.masked_fill(col_mask, float("-inf")) - logits_txt_self = logits_txt_self.masked_fill(col_mask, float("-inf")) - logits_img_ori = logits_img_ori.masked_fill(col_mask, float("-inf")) - logits_txt_ori = logits_txt_ori.masked_fill(col_mask, float("-inf")) - logits_img_cl = logits_img_cl.masked_fill(col_mask, float("-inf")) - logits_txt_cl = logits_txt_cl.masked_fill(col_mask, float("-inf")) + logits_img_self = logits_img_self.masked_fill(col_mask, _MASKED_LOGIT_FILL) + logits_txt_self = logits_txt_self.masked_fill(col_mask, _MASKED_LOGIT_FILL) + logits_img_ori = logits_img_ori.masked_fill(col_mask, _MASKED_LOGIT_FILL) + logits_txt_ori = logits_txt_ori.masked_fill(col_mask, _MASKED_LOGIT_FILL) + logits_img_cl = logits_img_cl.masked_fill(col_mask, _MASKED_LOGIT_FILL) + logits_txt_cl = logits_txt_cl.masked_fill(col_mask, _MASKED_LOGIT_FILL) # --- Safe labels: recon rows fallback to first clip column --- labels = self.labels diff --git a/tzrec/loss/clip_loss_test.py b/tzrec/loss/clip_loss_test.py index f124c2ff8..11e4eb2b3 100644 --- a/tzrec/loss/clip_loss_test.py +++ b/tzrec/loss/clip_loss_test.py @@ -59,6 +59,19 @@ def test_all_recon_mask_zero_loss(self) -> None: self.assertTrue(torch.isfinite(out["clip_loss"])) self.assertAlmostEqual(out["clip_loss"].item(), 0.0, places=6) + def test_all_recon_mask_finite_gradient(self) -> None: + # Regression: with float("-inf") column fill an all-recon batch produced + # a NaN gradient (0 * NaN) that survived the row mask. The finite fill + # must keep the backward finite (and zero, since no clip row contributes). + loss_fn = MaskedCLIPLoss() + feats = self._features(6, 8) + mask = torch.zeros(6, dtype=torch.bool) + loss_fn(feats, mask)["clip_loss"].backward() + grad = feats["image_embed"].grad + self.assertIsNotNone(grad) + self.assertTrue(torch.isfinite(grad).all()) + self.assertAlmostEqual(grad.abs().sum().item(), 0.0, places=6) + def test_backward_flows_to_embeddings(self) -> None: loss_fn = MaskedCLIPLoss() feats = self._features(6, 8) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index da7dd04f5..6e285cdb6 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -92,8 +92,16 @@ def __init__( list(cfg.hidden_dims) if cfg.hidden_dims else [self._input_dim // 2] ) # latent_weight defaults to (1.0, 0.5) when the user leaves the - # repeated field empty. + # repeated field empty. It must be exactly [w1, w2] (encoder-side and + # codebook-side commitment weights); the quantizer unpacks it into two, + # so validate here with a field-named error instead of a cryptic + # unpack ValueError deep in ResidualVectorQuantizer. latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) + if len(latent_weight) != 2: + raise ValueError( + "latent_weight must have exactly 2 values [w1, w2], got " + f"{list(cfg.latent_weight)}" + ) use_sinkhorn = True sinkhorn_iters = 5 diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 7ed681dfb..d3887be6c 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -273,6 +273,20 @@ def test_rqvae_backward(self) -> None: ) self.assertTrue(has_grad) + def test_latent_weight_wrong_length_raises(self) -> None: + """latent_weight must be exactly [w1, w2]; a bad length fails fast.""" + for bad in ([1.0], [1.0, 0.5, 0.25]): + cfg = sid_model_pb2.SidRqvae( + input_dim=32, + embed_dim=8, + codebook=[16, 16], + kmeans_init=False, + latent_weight=bad, + ) + model_config = model_pb2.ModelConfig(sid_rqvae=cfg) + with self.assertRaisesRegex(ValueError, "latent_weight"): + SidRqvae(model_config=model_config, features=[], labels=[]) + def test_clip_mask_uses_flag_not_equality(self) -> None: """The is_clip_pair flag, not bit-exact equality, drives routing. diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index de4254c17..d160a57fc 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -67,8 +67,11 @@ message SidRqvae { repeated float latent_weight = 11; // STE rotation trick. optional bool rotation_trick = 12 [default = false]; - // KMeans codebook initialization on first training forward. - optional bool kmeans_init = 13 [default = true]; + // KMeans codebook initialization on first training forward. Default false: + // the FAISS warm-start needs the first batch to have >= max(codebook) rows + // (faiss requires N >= K) and, under DDP, a rank-0 fit failure would hang + // the other ranks on the centroid broadcast — so it is opt-in. + optional bool kmeans_init = 13 [default = false]; // === Optional sub-module configs === // Sinkhorn uniform assignment. Default behavior when this block is From 12bd93e9626aa8e1db67b0417164d5b0f7deee1d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 11:27:25 +0000 Subject: [PATCH 086/129] [fix] RQ-VAE: make Gumbel-Softmax forward_mode actually functional (#1) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Gumbel was doubly broken: (a) it crashed on the default config (use_sinkhorn defaults True, which the constructor asserts against), and (b) even with Sinkhorn off it was a silent no-op — _quantize_layer discarded the soft gumbel embedding and used the hard lookup, and the assignment distances were computed under @torch.no_grad(), so no gradient reached the encoder; it trained identically to STE. Make it real: - Distances are now grad-capable (drop @torch.no_grad() from _compute_distances / _squared_euclidean_distance; the STE/Sinkhorn paths still assign under _find_nearest_embedding's own no_grad). clamp is now out-of-place (autograd-safe). - VectorQuantize.quantize: for gumbel-train, compute differentiable logits (-distances), gumbel-softmax sample, and take `ids` from the sample so the saved code matches the embedding actually used (no argmin/sample desync). - ResidualVectorQuantizer: store the mode; for gumbel-train, _quantize_layer returns the soft embedding and forward walks the LIVE input and skips the aggregate STE, so the encoder + codebook are trained through the soft assignment. STE behaviour is unchanged (detached walk + aggregate STE). - Auto-disable Sinkhorn for gumbel (with a warning) instead of crashing; warn that rotation_trick is ignored for gumbel. Tests: gumbel grad reaches encoder AND codebook via the recon path (STE gives the codebook zero recon-path grad); ids match the embedding used; default gumbel config constructs (Sinkhorn auto-off); distances differentiable. CPU-validated (autograd is device-agnostic); ruff clean; full SID suite green. --- .../modules/sid/residual_vector_quantizer.py | 63 ++++++++++++++----- .../sid/residual_vector_quantizer_test.py | 62 +++++++++++++++++- tzrec/modules/sid/vector_quantize.py | 39 +++++++----- tzrec/modules/sid/vector_quantize_test.py | 32 ++++++++++ 4 files changed, 165 insertions(+), 31 deletions(-) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index f86b2ca6e..2492ea1e3 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -161,6 +161,24 @@ def __init__( f"choose from {list(self._FORWARD_MODE_MAP.keys())}" ) mode_enum = self._FORWARD_MODE_MAP[forward_mode] + self._forward_mode = mode_enum + is_gumbel = mode_enum == QuantizeForwardMode.GUMBEL_SOFTMAX + # Gumbel drives its own assignment, so Sinkhorn is incompatible (and the + # two would desync code vs embedding). Auto-disable rather than crash on + # the proto default (use_sinkhorn omitted -> True). + if is_gumbel and use_sinkhorn: + logger.warning( + "forward_mode=gumbel_softmax is incompatible with Sinkhorn; " + "disabling use_sinkhorn for this quantizer." + ) + use_sinkhorn = False + # The aggregate STE (and its rotation-trick variant) is the STE encoder + # gradient path; Gumbel carries its own, so the rotation trick is unused. + if is_gumbel and rotation_trick: + logger.warning( + "rotation_trick has no effect with forward_mode=gumbel_softmax " + "(the aggregate STE is skipped); ignoring it." + ) if isinstance(distance_type, str): distance_types = [distance_type] * n_layers @@ -353,8 +371,12 @@ def _quantize_layer( """ layer = self.layers[layer_idx] out = layer.quantize(residual, temperature) - # Re-look up the raw codebook vector (not the soft STE/Gumbel - # ``out.embeddings``); see the docstring for why. + if self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX and self.training: + # Gumbel: the soft differentiable embedding carries gradient to both + # the encoder and the codebook, so use it directly (no aggregate STE). + return out.ids, out.embeddings + # STE / eval: re-look up the raw codebook vector (not the soft + # ``out.embeddings``); STE is applied once on the aggregate in forward. return out.ids, layer.lookup(out.ids) def forward( @@ -364,13 +386,20 @@ def forward( ) -> ResidualQuantizerOutput: """Forward the multi-layer residual quantization. - Training flow: - 1. If kmeans_init and not initialized -> init_embed_(input) - 2. Shared residual walk (:meth:`_residual_pass`) over the detached - input: per-layer assign + grad-carrying accumulation. - 3. Mean of per-layer commitment losses over the cumulative quants - (cos/l2 with latent_weight). - 4. STE gradient pass-through (or rotation trick). + Two encoder-gradient regimes by ``forward_mode``: + + - STE: the residual walk runs on the DETACHED input (the assignment is + non-differentiable), and the encoder gradient is re-attached once via + the aggregate STE in step 4. The codebook trains via the commitment + loss (the accumulated raw lookups keep grad). + - Gumbel-Softmax: the soft assignment is itself differentiable, so the + walk runs on the LIVE input and the gradient reaches the encoder and + codebook through the accumulated soft embeddings; the aggregate STE is + skipped. + + Steps: (1) kmeans_init on the first training forward; (2) residual walk; + (3) mean per-layer commitment loss over the cumulative quants; + (4) aggregate STE (STE mode only). Args: input (Tensor): input embeddings, shape (B, D). @@ -384,11 +413,15 @@ def forward( if self.training: self.init_embed_(input) - # Step 2: shared residual walk on the detached input (encoder grad - # flows only via the STE in step 4; the accumulated quants keep grad - # so the codebook still trains). cumulative[i] = sum after layer i. + is_gumbel = self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX + train_gumbel = is_gumbel and self.training + + # Step 2: residual walk. Gumbel keeps the live input (its soft + # assignment carries encoder grad); STE detaches and re-attaches grad in + # step 4. cumulative[i] = sum after layer i. + walk_input = input if train_gumbel else input.detach() cluster_ids, aggregated_quants, cumulative = self._residual_pass( - input.detach(), temperature + walk_input, temperature ) # Step 3: aggregate per-layer commitment loss @@ -396,9 +429,9 @@ def forward( torch.stack([self._single_commitment_loss(input, c) for c in cumulative]) ) - # Step 4: STE or rotation trick (quants_trunc = final accumulated) + # Step 4: aggregate STE (STE mode only; Gumbel already carries grad). quants_trunc = aggregated_quants - if self.training: + if self.training and not is_gumbel: if self.rotation_trick: quants_trunc = self._apply_rotation_trick(input, quants_trunc) else: diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index 3e7034f81..9d80bd42b 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -13,7 +13,67 @@ import torch -from tzrec.modules.sid.residual_vector_quantizer import faiss_residual_kmeans +from tzrec.modules.sid.residual_vector_quantizer import ( + ResidualVectorQuantizer, + faiss_residual_kmeans, +) + + +class GumbelResidualVQTest(unittest.TestCase): + """Gumbel-Softmax forward_mode: grad reaches encoder + codebook (not STE).""" + + def test_default_gumbel_config_disables_sinkhorn(self) -> None: + # use_sinkhorn defaults True; gumbel must auto-disable it (not crash). + rvq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=3, + n_embed=16, + forward_mode="gumbel_softmax", + use_sinkhorn=True, + kmeans_init=False, + ) + self.assertTrue(all(not layer.use_sinkhorn for layer in rvq.layers)) + + def test_gumbel_grad_flows_via_soft_assignment(self) -> None: + # The fix: the gradient from the reconstruction path (no commitment + # loss) must reach BOTH the encoder input and the codebook through the + # soft gumbel embedding. Under the old code it reached neither (the soft + # embedding was discarded), so gumbel silently trained like STE. + torch.manual_seed(0) + rvq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=3, + n_embed=16, + forward_mode="gumbel_softmax", + use_sinkhorn=False, + kmeans_init=False, + ) + rvq.train() + z = torch.randn(32, 8, requires_grad=True) + rvq(z).quantized_embeddings.sum().backward() + self.assertIsNotNone(z.grad) + self.assertGreater(z.grad.abs().sum().item(), 0.0) + cb_grad = rvq.layers[0].embedding.weight.grad + self.assertIsNotNone(cb_grad) + self.assertGreater(cb_grad.abs().sum().item(), 0.0) + + def test_ste_codebook_grad_is_detached_on_recon_path(self) -> None: + # Contrast: STE detaches the aggregate, so the recon path gives the + # codebook no gradient (it trains via the commitment loss instead). + torch.manual_seed(0) + rvq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=16, + forward_mode="ste", + use_sinkhorn=False, + kmeans_init=False, + ) + rvq.train() + z = torch.randn(16, 8, requires_grad=True) + rvq(z).quantized_embeddings.sum().backward() + cb_grad = rvq.layers[0].embedding.weight.grad + self.assertTrue(cb_grad is None or cb_grad.abs().sum().item() == 0.0) class FaissResidualKmeansTest(unittest.TestCase): diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index a3becf51f..1b30cfb0d 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -25,7 +25,6 @@ ) -@torch.no_grad() def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Squared L2 distance between rows of ``x`` and ``y``. @@ -38,11 +37,14 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso Device-agnostic (pure torch); the (N, K) product is small (N is the batch size). Kept branch-free (no data-dependent chunking on ``N``) so the - forward stays FX-traceable for torchrec's inference pipeline. + forward stays FX-traceable for torchrec's inference pipeline. NOT wrapped in + ``no_grad``: the Gumbel-Softmax path needs differentiable distances (the + STE/Sinkhorn callers run it under their own ``no_grad``). ``clamp`` is + out-of-place to stay autograd-safe. """ x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) - return (x_sq + y_sq - 2.0 * x @ y.t()).clamp_(min=0.0) + return (x_sq + y_sq - 2.0 * x @ y.t()).clamp(min=0.0) def _gumbel_softmax_sample( @@ -188,11 +190,13 @@ def __init__( self.embedding = nn.Embedding(n_embed, embed_dim) nn.init.kaiming_uniform_(self.embedding.weight) - @torch.no_grad() def _compute_distances(self, x: torch.Tensor) -> torch.Tensor: """Compute distances between input vectors and codebook entries. - Supports L2 and cosine distance metrics. + Supports L2 and cosine distance metrics. NOT wrapped in ``no_grad``: + the Gumbel-Softmax path calls this directly and needs the gradient to + flow to the encoder; the STE/Sinkhorn path calls it inside + :meth:`_find_nearest_embedding`, which is itself ``no_grad``. Args: x (Tensor): input vectors, shape (B, D). @@ -274,18 +278,23 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: Returns: QuantizeOutput: named tuple of (embeddings, ids). """ - # Step 1-2: find nearest codebook entry - ids, distances = self._find_nearest_embedding(x) - - # Step 3: differentiable embedding. Gumbel takes a separate path - # that combines all codebook entries; STE goes through a single - # embedding lookup. + # Gumbel-Softmax: distances must be differentiable so the gradient + # reaches the encoder through the soft assignment, so they are computed + # WITH grad here (the STE/eval branch below assigns under no_grad). The + # straight-through hard sample drives BOTH the embedding and ``ids``, so + # the saved code matches the codebook vector actually reconstructed + # (unlike argmin, which the gumbel noise can disagree with). Sinkhorn is + # disabled for this mode (see ResidualVectorQuantizer.__init__). if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: - weights = _gumbel_softmax_sample( - -distances, temperature=temperature, hard=True - ) + logits = -self._compute_distances(x) # (B, n_embed), differentiable + weights = _gumbel_softmax_sample(logits, temperature=temperature, hard=True) emb = weights @ self.embedding.weight - elif self.training and self.forward_mode == QuantizeForwardMode.STE: + ids = weights.argmax(dim=-1) + return QuantizeOutput(embeddings=emb, ids=ids) + + # STE / eval: nearest-neighbour assignment under no_grad. + ids, _ = self._find_nearest_embedding(x) + if self.training and self.forward_mode == QuantizeForwardMode.STE: quantized = self.embedding(ids) # Straight-Through Estimator: gradient passes through emb = x + (quantized - x).detach() diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index d7a7e3fc8..b6a05f269 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -92,6 +92,38 @@ def test_eval_forward_is_plain_lookup(self) -> None: # In eval, emb == embedding(ids) exactly. torch.testing.assert_close(out.embeddings, vq.embedding(out.ids)) + def test_gumbel_train_ids_match_embedding(self) -> None: + # In gumbel training the saved code must index the codebook vector + # actually used (the hard sample), so emb forward == embedding(ids). + # (Under the old code ids came from argmin and could disagree with the + # gumbel-sampled embedding.) + torch.manual_seed(0) + vq = VectorQuantize( + embed_dim=8, + n_embed=16, + forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, + use_sinkhorn=False, + ) + vq.train() + out = vq.quantize(torch.randn(5, 8)) + torch.testing.assert_close(out.embeddings, vq.embedding(out.ids)) + + def test_gumbel_train_distances_are_differentiable(self) -> None: + # Gumbel needs the assignment differentiable: grad must reach the input. + torch.manual_seed(0) + vq = VectorQuantize( + embed_dim=8, + n_embed=16, + forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, + use_sinkhorn=False, + ) + vq.train() + x = torch.randn(5, 8, requires_grad=True) + vq.quantize(x).embeddings.sum().backward() + self.assertIsNotNone(x.grad) + self.assertTrue(torch.isfinite(x.grad).all()) + self.assertGreater(x.grad.abs().sum().item(), 0.0) + if __name__ == "__main__": unittest.main() From 15c5210fa2212d27f0d786bacb085d5de740e910 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 11:53:44 +0000 Subject: [PATCH 087/129] [style] SID: trim verbose comments and docstrings Comment/docstring-only cleanup across the RQ-VAE files touched this session (clip_loss, vector_quantize, residual_vector_quantizer, sid_rqvae): condense the multi-line explanations added with the gumbel / empty-mask / latent_weight fixes to one or two lines each. No code changed; ruff clean; SID suite green. --- tzrec/loss/clip_loss.py | 22 ++---- tzrec/models/sid_rqvae.py | 16 ++--- .../modules/sid/residual_vector_quantizer.py | 72 +++++++------------ tzrec/modules/sid/vector_quantize.py | 25 +++---- 4 files changed, 45 insertions(+), 90 deletions(-) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py index 7fd1092a6..412150e2a 100644 --- a/tzrec/loss/clip_loss.py +++ b/tzrec/loss/clip_loss.py @@ -19,13 +19,9 @@ from torch.nn import functional as F from torch.nn.modules.loss import _Loss -# Finite large-negative fill for masked-out (recon) logit columns. ``exp()`` of -# it underflows to 0 (same masking effect as ``-inf``), but unlike ``-inf`` it -# keeps cross-entropy and its gradient finite when a row has no valid column -# (a batch/rank with zero clip rows). With ``-inf`` such a row yields NaN and, -# crucially, a NaN gradient (``0 * NaN``) that survives the row mask and poisons -# the encoder. Finite, so a no-clip batch contributes exactly 0 with 0 gradient. -# Kept well within fp16 range (-65504) so it is safe under FP16/BF16 autocast. +# Fill for masked-out (recon) logit columns. exp() underflows to 0 like -inf, +# but stays finite so an all-recon row gives finite CE/grad (not 0*NaN). Within +# fp16 range, so safe under autocast. _MASKED_LOGIT_FILL = -1e4 @@ -111,15 +107,11 @@ def _masked_cross_entropy( """ ce_i = F.cross_entropy(logits_i, safe_labels, reduction="none") ce_t = F.cross_entropy(logits_t, safe_labels, reduction="none") - # Backstop only: the finite _MASKED_LOGIT_FILL already keeps the - # all-recon row finite, so this guards solely against a non-finite - # logit arriving from upstream (e.g. an overflowed logit_scale). + # Backstop against a non-finite upstream logit (e.g. overflowed scale). ce_i = torch.nan_to_num(ce_i, nan=0.0) ce_t = torch.nan_to_num(ce_t, nan=0.0) - # Row mask: only clip rows contribute; clamp(min=1) keeps a no-clip - # batch at 0 (not 0/0). Combined with the finite fill, a batch with no - # clip rows yields exactly 0 loss and 0 gradient. + # Only clip rows contribute; clamp(min=1) keeps a no-clip batch at 0. n_valid = clip_mask.float().sum().clamp(min=1) return ((ce_i + ce_t) * clip_mask.float()).sum() / (2 * n_valid) @@ -173,9 +165,7 @@ def forward( logits_img_cl = logit_scale_cl * image_embed @ image_embed_all_ori.t() logits_txt_cl = logit_scale_cl * text_embed @ text_embed_all_ori.t() - # --- Column mask: recon columns -> large-negative (not as negatives) --- - # Finite fill (not -inf) so an all-recon row keeps a finite, non-NaN - # gradient; see _MASKED_LOGIT_FILL. + # Mask recon columns out of the negatives (finite fill, see above). clip_mask_all = self._gather_bool_mask(clip_mask) col_mask = (~clip_mask_all).unsqueeze(0) # (1, B_global) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 6e285cdb6..15273886c 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -91,11 +91,8 @@ def __init__( hidden_dims = ( list(cfg.hidden_dims) if cfg.hidden_dims else [self._input_dim // 2] ) - # latent_weight defaults to (1.0, 0.5) when the user leaves the - # repeated field empty. It must be exactly [w1, w2] (encoder-side and - # codebook-side commitment weights); the quantizer unpacks it into two, - # so validate here with a field-named error instead of a cryptic - # unpack ValueError deep in ResidualVectorQuantizer. + # Empty -> default (1.0, 0.5); else must be exactly [w1, w2] (the + # quantizer unpacks two). Validate here for a field-named error. latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) if len(latent_weight) != 2: raise ValueError( @@ -344,12 +341,9 @@ def update_train_metric( ) -> None: """Update train metric state. - Overrides the BaseSidModel no-op: RQ-VAE has a train-time - reconstruction (the decoder output), so it can report a train-path mse. - The eval metrics (mse / rel_loss / unique_sid_ratio over - ``predictions["x_hat"]`` and ``["codes"]``) are handled by - ``BaseSidModel.update_metric`` — SidRqvae emits ``x_hat``, so it needs - no ``update_metric`` override. + Overrides the BaseSidModel no-op: RQ-VAE has a train-time reconstruction + (the decoder output), so it reports a train-path mse. Eval metrics are + handled by ``BaseSidModel.update_metric`` (SidRqvae emits ``x_hat``). Args: predictions (dict): a dict of predicted result. diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 2492ea1e3..882bf54df 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -41,10 +41,8 @@ def faiss_residual_kmeans( from the first training batch — the same FAISS backend the offline RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. - Device handling (CPU + GPU): the FAISS fit is always CPU (``samples`` is - copied to host as fp32 numpy), and the returned centroids are moved back to - ``samples.device``. So an RQ-VAE training on GPU gets GPU centroids while - the fit itself stays on CPU — no faiss-gpu build required. + CPU+GPU: the fit is always CPU (host fp32 numpy copy); centroids are + returned on ``samples.device`` — no faiss-gpu build needed. Args: samples (Tensor): data points, shape (N, D). @@ -163,22 +161,14 @@ def __init__( mode_enum = self._FORWARD_MODE_MAP[forward_mode] self._forward_mode = mode_enum is_gumbel = mode_enum == QuantizeForwardMode.GUMBEL_SOFTMAX - # Gumbel drives its own assignment, so Sinkhorn is incompatible (and the - # two would desync code vs embedding). Auto-disable rather than crash on - # the proto default (use_sinkhorn omitted -> True). + # Sinkhorn is incompatible with Gumbel; auto-disable (the proto default + # is on) instead of crashing. if is_gumbel and use_sinkhorn: - logger.warning( - "forward_mode=gumbel_softmax is incompatible with Sinkhorn; " - "disabling use_sinkhorn for this quantizer." - ) + logger.warning("gumbel_softmax: disabling incompatible use_sinkhorn.") use_sinkhorn = False - # The aggregate STE (and its rotation-trick variant) is the STE encoder - # gradient path; Gumbel carries its own, so the rotation trick is unused. + # Gumbel skips the aggregate STE, so the rotation trick is unused. if is_gumbel and rotation_trick: - logger.warning( - "rotation_trick has no effect with forward_mode=gumbel_softmax " - "(the aggregate STE is skipped); ignoring it." - ) + logger.warning("gumbel_softmax: rotation_trick has no effect; ignoring.") if isinstance(distance_type, str): distance_types = [distance_type] * n_layers @@ -357,8 +347,8 @@ def _quantize_layer( ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize one layer's residual via its ``VectorQuantize`` layer. - Returns the raw (un-STE'd) codebook vector so gradient still flows into - the codebook; STE is applied once on the aggregate in :meth:`forward`. + STE: raw codebook vector (STE applied on the aggregate in :meth:`forward`). + Gumbel: the soft embedding (carries grad directly). Args: layer_idx (int): quantization layer index. @@ -372,11 +362,9 @@ def _quantize_layer( layer = self.layers[layer_idx] out = layer.quantize(residual, temperature) if self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX and self.training: - # Gumbel: the soft differentiable embedding carries gradient to both - # the encoder and the codebook, so use it directly (no aggregate STE). + # Gumbel: soft embedding carries grad to encoder + codebook. return out.ids, out.embeddings - # STE / eval: re-look up the raw codebook vector (not the soft - # ``out.embeddings``); STE is applied once on the aggregate in forward. + # STE / eval: raw codebook vector; STE applied on the aggregate in forward. return out.ids, layer.lookup(out.ids) def forward( @@ -386,20 +374,16 @@ def forward( ) -> ResidualQuantizerOutput: """Forward the multi-layer residual quantization. - Two encoder-gradient regimes by ``forward_mode``: + Encoder gradient by ``forward_mode``: - - STE: the residual walk runs on the DETACHED input (the assignment is - non-differentiable), and the encoder gradient is re-attached once via - the aggregate STE in step 4. The codebook trains via the commitment - loss (the accumulated raw lookups keep grad). - - Gumbel-Softmax: the soft assignment is itself differentiable, so the - walk runs on the LIVE input and the gradient reaches the encoder and - codebook through the accumulated soft embeddings; the aggregate STE is - skipped. + - STE: walk the DETACHED input, re-attach encoder grad via the aggregate + STE (step 4); the codebook trains via the commitment loss. + - Gumbel: the soft assignment is differentiable, so walk the LIVE input + (grad reaches encoder + codebook through the soft embeddings) and skip + the aggregate STE. - Steps: (1) kmeans_init on the first training forward; (2) residual walk; - (3) mean per-layer commitment loss over the cumulative quants; - (4) aggregate STE (STE mode only). + Steps: (1) kmeans_init (first training forward); (2) residual walk; + (3) mean per-layer commitment loss; (4) aggregate STE (STE only). Args: input (Tensor): input embeddings, shape (B, D). @@ -409,27 +393,26 @@ def forward( ResidualQuantizerOutput: (cluster_ids, quantized_embeddings, quantization_loss). """ - # Step 1: KMeans initialization (first training forward only) + # Step 1: KMeans init (first training forward only) if self.training: self.init_embed_(input) is_gumbel = self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX train_gumbel = is_gumbel and self.training - # Step 2: residual walk. Gumbel keeps the live input (its soft - # assignment carries encoder grad); STE detaches and re-attaches grad in - # step 4. cumulative[i] = sum after layer i. + # Step 2: residual walk. Gumbel walks the live input; STE detaches and + # re-attaches grad in step 4. cumulative[i] = sum after layer i. walk_input = input if train_gumbel else input.detach() cluster_ids, aggregated_quants, cumulative = self._residual_pass( walk_input, temperature ) - # Step 3: aggregate per-layer commitment loss + # Step 3: mean per-layer commitment loss commitment_loss = torch.mean( torch.stack([self._single_commitment_loss(input, c) for c in cumulative]) ) - # Step 4: aggregate STE (STE mode only; Gumbel already carries grad). + # Step 4: aggregate STE (STE only; Gumbel already carries grad) quants_trunc = aggregated_quants if self.training and not is_gumbel: if self.rotation_trick: @@ -447,11 +430,8 @@ def forward( def get_codebook_embeddings(self, layer_idx: int) -> torch.Tensor: """Get codebook embedding weights for a specific layer. - Detached: the layer's ``get_codebook_embeddings`` returns the live - ``nn.Embedding.weight`` (a grad-requiring leaf, needed by the training - ``lookup`` path), but this is a read-only accessor for export/inspection, - so it returns a non-grad view — matching the K-Means sibling, whose - codebook is a buffer. + Detached read-only view for export/inspection (the layer's weight is a + grad leaf, needed by the training ``lookup`` path). Args: layer_idx (int): index of the quantization layer. diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 1b30cfb0d..eb8c942fc 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -35,11 +35,8 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso Returns: Tensor: squared distances, shape (N, K). - Device-agnostic (pure torch); the (N, K) product is small (N is the batch - size). Kept branch-free (no data-dependent chunking on ``N``) so the - forward stays FX-traceable for torchrec's inference pipeline. NOT wrapped in - ``no_grad``: the Gumbel-Softmax path needs differentiable distances (the - STE/Sinkhorn callers run it under their own ``no_grad``). ``clamp`` is + Branch-free (FX-traceable). Not ``no_grad`` (Gumbel needs grad here; the + STE/Sinkhorn callers wrap it in their own ``no_grad``); ``clamp`` is out-of-place to stay autograd-safe. """ x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) @@ -191,12 +188,10 @@ def __init__( nn.init.kaiming_uniform_(self.embedding.weight) def _compute_distances(self, x: torch.Tensor) -> torch.Tensor: - """Compute distances between input vectors and codebook entries. + """Compute L2/cosine distances between inputs and codebook entries. - Supports L2 and cosine distance metrics. NOT wrapped in ``no_grad``: - the Gumbel-Softmax path calls this directly and needs the gradient to - flow to the encoder; the STE/Sinkhorn path calls it inside - :meth:`_find_nearest_embedding`, which is itself ``no_grad``. + Not ``no_grad``: Gumbel calls this directly for the encoder gradient; + the STE/Sinkhorn path wraps it via ``no_grad`` :meth:`_find_nearest_embedding`. Args: x (Tensor): input vectors, shape (B, D). @@ -278,13 +273,9 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: Returns: QuantizeOutput: named tuple of (embeddings, ids). """ - # Gumbel-Softmax: distances must be differentiable so the gradient - # reaches the encoder through the soft assignment, so they are computed - # WITH grad here (the STE/eval branch below assigns under no_grad). The - # straight-through hard sample drives BOTH the embedding and ``ids``, so - # the saved code matches the codebook vector actually reconstructed - # (unlike argmin, which the gumbel noise can disagree with). Sinkhorn is - # disabled for this mode (see ResidualVectorQuantizer.__init__). + # Gumbel: grad-enabled distances (so the encoder gets gradient); the + # hard sample drives both emb and ids, so the saved code matches the + # vector used. Sinkhorn is off here (ResidualVectorQuantizer.__init__). if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: logits = -self._compute_distances(x) # (B, n_embed), differentiable weights = _gumbel_softmax_sample(logits, temperature=temperature, hard=True) From 441cf194d4935940d57b728089a476ee0e31b8cf Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 11 Jun 2026 12:10:25 +0000 Subject: [PATCH 088/129] [simplify] RQ-VAE: dedup gumbel predicate; move latent_weight check to quantizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two cleanups from /simplify (no behavior change): - The "training pass in Gumbel mode" predicate was spelled three ways across ResidualVectorQuantizer._quantize_layer and forward. Collapse into one _train_gumbel() helper used by both (drops the is_gumbel/train_gumbel pair). - Move the latent_weight arity check from SidRqvae.__init__ into ResidualVectorQuantizer.__init__, right where it unpacks w1, w2 — the contract's owner, so any caller (not just SidRqvae) gets the field-named error instead of a cryptic unpack ValueError. SidRqvae keeps only the empty -> default handling. ruff clean; sid_rqvae / residual_vector_quantizer(+dist) tests pass (the latent_weight wrong-length test still raises via the quantizer). --- tzrec/models/sid_rqvae.py | 8 +------- tzrec/modules/sid/residual_vector_quantizer.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 15273886c..d888c3ede 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -91,14 +91,8 @@ def __init__( hidden_dims = ( list(cfg.hidden_dims) if cfg.hidden_dims else [self._input_dim // 2] ) - # Empty -> default (1.0, 0.5); else must be exactly [w1, w2] (the - # quantizer unpacks two). Validate here for a field-named error. + # Empty -> default (1.0, 0.5); the quantizer validates the arity. latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) - if len(latent_weight) != 2: - raise ValueError( - "latent_weight must have exactly 2 values [w1, w2], got " - f"{list(cfg.latent_weight)}" - ) use_sinkhorn = True sinkhorn_iters = 5 diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 882bf54df..524afd18e 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -147,6 +147,11 @@ def __init__( self.commitment_loss_type = commitment_loss self.rotation_trick = rotation_trick + if len(latent_weight) != 2: + raise ValueError( + f"latent_weight must have exactly 2 values [w1, w2], got " + f"{list(latent_weight)}" + ) self.commitment_w1, self.commitment_w2 = latent_weight # ``initted`` is the kmeans_init guard: True means "codebook has @@ -339,6 +344,11 @@ def _apply_rotation_trick( x_unsq - 2 * sum_projection + 2 * rescaled_embeddings ).squeeze(1) + def _train_gumbel(self) -> bool: + """Training pass in Gumbel mode (its soft assignment carries grad).""" + is_gumbel = self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX + return self.training and is_gumbel + def _quantize_layer( self, layer_idx: int, @@ -361,7 +371,7 @@ def _quantize_layer( """ layer = self.layers[layer_idx] out = layer.quantize(residual, temperature) - if self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX and self.training: + if self._train_gumbel(): # Gumbel: soft embedding carries grad to encoder + codebook. return out.ids, out.embeddings # STE / eval: raw codebook vector; STE applied on the aggregate in forward. @@ -397,8 +407,7 @@ def forward( if self.training: self.init_embed_(input) - is_gumbel = self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX - train_gumbel = is_gumbel and self.training + train_gumbel = self._train_gumbel() # Step 2: residual walk. Gumbel walks the live input; STE detaches and # re-attaches grad in step 4. cumulative[i] = sum after layer i. @@ -414,7 +423,7 @@ def forward( # Step 4: aggregate STE (STE only; Gumbel already carries grad) quants_trunc = aggregated_quants - if self.training and not is_gumbel: + if self.training and not train_gumbel: if self.rotation_trick: quants_trunc = self._apply_rotation_trick(input, quants_trunc) else: From ba1d7f9d4fa753d181874b197116dbfbcc97ca5c Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 06:19:28 +0000 Subject: [PATCH 089/129] [fix] code-review: structural CLIP mask fill (finfo.min) + drop dead VQ branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two findings from the high-effort code review: - #1 CLIP masked-column fill: the hardcoded -1e4 is not provably below real logits (logit_scale is an unclamped exp() and the *_ori operands are not L2-normalized, matching the al_sid reference), so a masked recon column could become competitive in the softmax/eval-argmax and leak into the loss. Replace it with torch.finfo(logits.dtype).min — below any real finite logit (masks like -inf regardless of scale) yet finite, so an all-recon row still gives a finite CE/grad instead of 0*NaN. Verified on a scale=5000, ori_norm=50 stress case that would have defeated -1e4. - #4 dead branch: after the gumbel early-return, STE is the only remaining training mode (the enum has two members), so `elif self.training: raise` was unreachable. Collapse the tail to a clean train(STE)/eval if-else. ruff clean; clip_loss / vector_quantize / sid_rqvae tests pass. --- tzrec/loss/clip_loss.py | 25 ++++++++++++------------- tzrec/modules/sid/vector_quantize.py | 9 ++++----- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py index 412150e2a..abc569ed1 100644 --- a/tzrec/loss/clip_loss.py +++ b/tzrec/loss/clip_loss.py @@ -19,11 +19,6 @@ from torch.nn import functional as F from torch.nn.modules.loss import _Loss -# Fill for masked-out (recon) logit columns. exp() underflows to 0 like -inf, -# but stays finite so an all-recon row gives finite CE/grad (not 0*NaN). Within -# fp16 range, so safe under autocast. -_MASKED_LOGIT_FILL = -1e4 - class MaskedCLIPLoss(_Loss): """Masked CLIP loss for mixed recon+clip batches. @@ -165,16 +160,20 @@ def forward( logits_img_cl = logit_scale_cl * image_embed @ image_embed_all_ori.t() logits_txt_cl = logit_scale_cl * text_embed @ text_embed_all_ori.t() - # Mask recon columns out of the negatives (finite fill, see above). + # Mask recon columns out of the negatives. Fill with the dtype's most + # negative finite value: provably below any real logit (so it masks like + # -inf regardless of logit_scale), but finite so an all-recon row gives + # a finite CE/grad instead of 0*NaN. clip_mask_all = self._gather_bool_mask(clip_mask) col_mask = (~clip_mask_all).unsqueeze(0) # (1, B_global) - - logits_img_self = logits_img_self.masked_fill(col_mask, _MASKED_LOGIT_FILL) - logits_txt_self = logits_txt_self.masked_fill(col_mask, _MASKED_LOGIT_FILL) - logits_img_ori = logits_img_ori.masked_fill(col_mask, _MASKED_LOGIT_FILL) - logits_txt_ori = logits_txt_ori.masked_fill(col_mask, _MASKED_LOGIT_FILL) - logits_img_cl = logits_img_cl.masked_fill(col_mask, _MASKED_LOGIT_FILL) - logits_txt_cl = logits_txt_cl.masked_fill(col_mask, _MASKED_LOGIT_FILL) + neg_fill = torch.finfo(logits_img_self.dtype).min + + logits_img_self = logits_img_self.masked_fill(col_mask, neg_fill) + logits_txt_self = logits_txt_self.masked_fill(col_mask, neg_fill) + logits_img_ori = logits_img_ori.masked_fill(col_mask, neg_fill) + logits_txt_ori = logits_txt_ori.masked_fill(col_mask, neg_fill) + logits_img_cl = logits_img_cl.masked_fill(col_mask, neg_fill) + logits_txt_cl = logits_txt_cl.masked_fill(col_mask, neg_fill) # --- Safe labels: recon rows fallback to first clip column --- labels = self.labels diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index eb8c942fc..d0e30db5b 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -283,14 +283,13 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: ids = weights.argmax(dim=-1) return QuantizeOutput(embeddings=emb, ids=ids) - # STE / eval: nearest-neighbour assignment under no_grad. + # STE / eval: nearest-neighbour assignment under no_grad. (Gumbel + # early-returned above; STE is the only remaining training mode.) ids, _ = self._find_nearest_embedding(x) - if self.training and self.forward_mode == QuantizeForwardMode.STE: + if self.training: + # Straight-Through Estimator: gradient passes through. quantized = self.embedding(ids) - # Straight-Through Estimator: gradient passes through emb = x + (quantized - x).detach() - elif self.training: - raise ValueError(f"Unsupported forward mode: {self.forward_mode}") else: emb = self.embedding(ids) From c5102122de694cd2fda88fe99db3a194e0c92965 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 06:40:40 +0000 Subject: [PATCH 090/129] [refactor] SID: share one-layer FAISS fit; tests for modified modules (#6) Code-review #6: faiss_residual_kmeans (RQ-VAE warm-start) and ResidualKMeansQuantizer.train_offline duplicated the per-layer FAISS loop, and the warm-start copy lacked train_offline's gpu-kwarg strip and N>=K guard. Extract the shared one-layer primitive `faiss_kmeans_fit(x, dim, k, kwargs)` in kmeans_quantize.py (the SID k-means home): it imports faiss, strips a stale `gpu` kwarg, guards N>=K with a clear message, fits, and returns the trained faiss.Kmeans. Each caller keeps its own residual loop (the RQ-Kmeans chunked search / normalize / logging and the RQ-VAE device handling stay put). The RQ-VAE warm-start gains the N>=K guard for free. Tests (CPU-only; the FAISS fit never uses GPU): - kmeans_quantize_test: faiss_kmeans_fit fits + raises on N None: self.assertIsNotNone(feats["image_embed"].grad) self.assertTrue(torch.isfinite(feats["image_embed"].grad).all()) + def test_mask_holds_under_large_scale(self) -> None: + # The column fill is finfo.min (below any real logit) rather than a + # hardcoded -1e4, so masking holds even when logit_scale is large and + # the *_ori operands are un-normalized (real logits can dwarf 1e4). + # Loss/grad must stay finite and acc valid; eval exercises the argmax. + loss_fn = MaskedCLIPLoss() + loss_fn.eval() + feats = self._features(6, 8) + big = torch.tensor(3000.0) + feats["logit_scale"] = big + feats["logit_scale_self"] = big + feats["logit_scale_cl"] = big + feats["image_embed_ori"] = feats["image_embed_ori"] * 50 + feats["text_embed_ori"] = feats["text_embed_ori"] * 50 + mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool) + out = loss_fn(feats, mask) + self.assertTrue(torch.isfinite(out["clip_loss"])) + loss_fn.train() + feats["image_embed"].grad = None + loss_fn(feats, mask)["clip_loss"].backward() + self.assertTrue(torch.isfinite(feats["image_embed"].grad).all()) + if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index 6eb5b940a..a4e12554a 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -19,9 +19,11 @@ * :class:`ReservoirSampler` — bounded uniform stream sample (Vitter Algorithm R) that :class:`~tzrec.models.sid_rqkmeans.SidRqkmeans` fills during training to feed the one-shot FAISS fit. +* :func:`faiss_kmeans_fit` — the shared one-layer FAISS fit behind both SID + residual-K-Means loops (RQ-VAE warm-start and offline RQ-K-Means). """ -from typing import Optional +from typing import Any, Dict, Optional import torch @@ -30,6 +32,60 @@ from tzrec.utils.logging_util import logger +def faiss_kmeans_fit( + x: Any, + dim: int, + n_clusters: int, + faiss_kmeans_kwargs: Optional[Dict] = None, +) -> Any: + """Train one ``faiss.Kmeans(dim, n_clusters)`` on ``x`` and return it. + + The shared one-layer FAISS fit behind both SID residual-K-Means loops — the + RQ-VAE warm-start + (:func:`~tzrec.modules.sid.residual_vector_quantizer.faiss_residual_kmeans`) + and the offline RQ-K-Means + (:meth:`~tzrec.modules.sid.residual_kmeans_quantizer.ResidualKMeansQuantizer.train_offline`). + The caller reads ``km.centroids`` and runs assignment via + ``km.index.search``, keeping its own residual / chunking / device logic. + + Strips a stale ``gpu`` kwarg (a faiss-gpu build must not target an absent + GPU) and guards ``N >= n_clusters`` with a clear message before faiss's + opaque C++ throw. ``x`` may be a numpy array or a torch tensor + (``faiss.contrib.torch_utils``). + + Args: + x: data points, shape (N, dim) — numpy array or torch tensor. + dim (int): feature dimension. + n_clusters (int): number of centroids (codebook size). + faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans``. + + Returns: + The trained ``faiss.Kmeans`` (read ``.centroids`` / ``.index``). + + Raises: + ImportError: if ``faiss`` is not installed. + RuntimeError: if ``x`` has fewer than ``n_clusters`` rows. + """ + try: + import faiss + except ImportError as e: + raise ImportError( + "faiss is required for SID residual K-Means. Install via " + "`pip install faiss-cpu` or `pip install faiss-gpu`." + ) from e + + kwargs = dict(faiss_kmeans_kwargs or {}) + kwargs.pop("gpu", None) + n = int(x.shape[0]) + if n < n_clusters: + raise RuntimeError( + f"need >= {n_clusters} points to fit the codebook, got N={n}" + ) + km = faiss.Kmeans(dim, n_clusters, **kwargs) + km.train(x) + return km + + class ReservoirSampler: """Bounded uniform sample of a stream (Vitter Algorithm R). diff --git a/tzrec/modules/sid/kmeans_quantize_test.py b/tzrec/modules/sid/kmeans_quantize_test.py index 2f2883562..59ec45a01 100644 --- a/tzrec/modules/sid/kmeans_quantize_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -16,9 +16,33 @@ from tzrec.modules.sid.kmeans_quantize import ( KMeansQuantizeLayer, ReservoirSampler, + faiss_kmeans_fit, ) +class FaissKmeansFitTest(unittest.TestCase): + """Tests for the shared one-layer FAISS fit primitive.""" + + def _require_faiss(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + + def test_fit_returns_trained_kmeans(self) -> None: + self._require_faiss() + torch.manual_seed(0) + # numpy input (the RQ-VAE call path; no faiss torch-utils needed). + x = torch.randn(200, 6).numpy() + km = faiss_kmeans_fit(x, 6, 8, {"niter": 5, "seed": 1, "verbose": False}) + self.assertEqual(tuple(km.centroids.shape), (8, 6)) + + def test_raises_on_too_few_points(self) -> None: + self._require_faiss() + with self.assertRaisesRegex(RuntimeError, "need >= 8 points"): + faiss_kmeans_fit(torch.randn(4, 6).numpy(), 6, 8) + + class KMeansQuantizeLayerTest(unittest.TestCase): """Tests for the single KMeansQuantizeLayer.""" diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 11b06951c..2ebee0ec2 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -17,13 +17,12 @@ from typing import Dict, List, Optional, Tuple, Union -import faiss import faiss.contrib.torch_utils # noqa: F401 (registers torch tensor I/O) import torch from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer +from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer, faiss_kmeans_fit from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -202,10 +201,9 @@ def train_offline( # breaks that invariant, so clone then. x0 = x.clone() if (verbose and self.normalize_residuals) else None - # CPU-only fit (SidRqkmeans refuses CUDA). Drop any stale ``gpu`` kwarg - # so a faiss-gpu build can't target an absent GPU. + # CPU-only fit (SidRqkmeans refuses CUDA). The ``gpu`` kwarg is stripped + # inside faiss_kmeans_fit. kwargs = dict(self.faiss_kmeans_kwargs) - kwargs.pop("gpu", None) if verbose: logger.info( "[ResidualKMeansQuantizer] fitting %d-layer codebook on CPU " @@ -224,15 +222,14 @@ def train_offline( # Fresh Kmeans per layer so each can use its own K (non-uniform # codebooks). - kmeans = faiss.Kmeans( - self.embed_dim, self.n_embed_list[layer_idx], **kwargs + km = faiss_kmeans_fit( + x, self.embed_dim, self.n_embed_list[layer_idx], kwargs ) - kmeans.train(x) - centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32) + centroids = torch.as_tensor(km.centroids, dtype=torch.float32) for start in range(0, N, SEARCH_CHUNK): end = min(start + SEARCH_CHUNK, N) - _, idx = kmeans.index.search(x[start:end], 1) + _, idx = km.index.search(x[start:end], 1) idx = torch.as_tensor(idx).reshape(-1).long() q = centroids[idx] # (chunk, D) out[start:end] += q diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 524afd18e..227de4267 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -18,6 +18,7 @@ from torch import nn from torch.nn import functional as F +from tzrec.modules.sid.kmeans_quantize import faiss_kmeans_fit from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.modules.sid.types import ( QuantizeForwardMode, @@ -55,16 +56,8 @@ def faiss_residual_kmeans( Raises: ImportError: if ``faiss`` is not installed. + RuntimeError: if a layer has fewer points than its cluster count. """ - try: - import faiss - except ImportError as e: - raise ImportError( - "faiss is required for RQ-VAE kmeans_init. Install via " - "`pip install faiss-cpu` or `pip install faiss-gpu`." - ) from e - - kwargs = dict(faiss_kmeans_kwargs or {}) device = samples.device _, D = samples.shape # Own a contiguous fp32 numpy copy we mutate in place to form residuals. @@ -72,11 +65,10 @@ def faiss_residual_kmeans( res_centers: List[torch.Tensor] = [] for n_clusters in n_clusters_list: - kmeans = faiss.Kmeans(D, n_clusters, **kwargs) - kmeans.train(x) - centroids = kmeans.centroids.copy() # (K, D) + km = faiss_kmeans_fit(x, D, n_clusters, faiss_kmeans_kwargs) + centroids = km.centroids.copy() # (K, D) res_centers.append(torch.from_numpy(centroids).to(device)) - _, idx = kmeans.index.search(x, 1) + _, idx = km.index.search(x, 1) x -= centroids[idx.ravel()] # residual, in place return res_centers diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index 9d80bd42b..1c71936bf 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -96,6 +96,16 @@ def test_faiss_residual_kmeans_per_layer_centers(self) -> None: # Centroids come back on the input device (CPU fit, device-preserving). self.assertEqual(centers[0].device, samples.device) + def test_raises_on_too_few_points(self) -> None: + # Gained from the shared faiss_kmeans_fit primitive: a clear N>=K error + # before faiss's opaque C++ throw. + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + with self.assertRaisesRegex(RuntimeError, "need >= 8 points"): + faiss_residual_kmeans(torch.randn(4, 6), [8]) + if __name__ == "__main__": unittest.main() From 6596b9e7c2bb6897e5f61b487a91003ef501b834 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 08:11:28 +0000 Subject: [PATCH 091/129] [style] SID: ruff-format residual_quantizer_test (collapse one assertEqual) CI codestyle lane (ruff-format v0.15.11) reformatted this line; apply it so the pre-commit ruff-format hook passes. Formatter-only, no logic change. Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/gen_sid_rqvae_mock_data.py | 72 ++++++++++++++++++ examples/sid_rqvae_gumbel_clip_local.config | 78 ++++++++++++++++++++ tzrec/modules/sid/residual_quantizer_test.py | 4 +- 3 files changed, 151 insertions(+), 3 deletions(-) create mode 100644 examples/gen_sid_rqvae_mock_data.py create mode 100644 examples/sid_rqvae_gumbel_clip_local.config diff --git a/examples/gen_sid_rqvae_mock_data.py b/examples/gen_sid_rqvae_mock_data.py new file mode 100644 index 000000000..8f753c9d0 --- /dev/null +++ b/examples/gen_sid_rqvae_mock_data.py @@ -0,0 +1,72 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Write a mock embedding parquet for the SidRqvae Gumbel+CLIP smoke config. + +Columns match ``examples/sid_rqvae_gumbel_clip_local.config``: + item1_embedding (list[dim]) -- the SID input embedding + item2_embedding (list[dim]) -- the CLIP-paired embedding + is_contrastive (float32 scalar) -- 1.0 = CLIP pair, 0.0 = recon-only + +Usage: + python examples/gen_sid_rqvae_mock_data.py --out_dir ./tmp/sid_rqvae_mock \ + --num_rows 4096 --dim 512 --clip_ratio 0.5 +""" + +import argparse +import os + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + + +def main() -> None: + """Generate the mock parquet shard.""" + parser = argparse.ArgumentParser() + parser.add_argument("--out_dir", default="./tmp/sid_rqvae_mock") + parser.add_argument("--num_rows", type=int, default=4096) + parser.add_argument("--dim", type=int, default=512) + parser.add_argument( + "--clip_ratio", + type=float, + default=0.5, + help="fraction of rows flagged as CLIP pairs (is_contrastive=1)", + ) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + + rng = np.random.default_rng(args.seed) + item1 = rng.standard_normal((args.num_rows, args.dim)).astype(np.float32) + # item2 is a noisy view of item1 so the contrastive pairs are learnable. + item2 = (item1 + 0.1 * rng.standard_normal(item1.shape)).astype(np.float32) + is_clip = (rng.random(args.num_rows) < args.clip_ratio).astype(np.float32) + + os.makedirs(args.out_dir, exist_ok=True) + out_path = os.path.join(args.out_dir, "part-0.parquet") + pq.write_table( + pa.table( + { + "item1_embedding": pa.array(list(item1)), + "item2_embedding": pa.array(list(item2)), + "is_contrastive": pa.array(is_clip), + } + ), + out_path, + ) + print( + f"wrote {args.num_rows} rows (dim={args.dim}, " + f"clip_pairs={int(is_clip.sum())}) -> {out_path}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/sid_rqvae_gumbel_clip_local.config b/examples/sid_rqvae_gumbel_clip_local.config new file mode 100644 index 000000000..ed24b9ce3 --- /dev/null +++ b/examples/sid_rqvae_gumbel_clip_local.config @@ -0,0 +1,78 @@ +# SidRqvae end-to-end smoke config — Gumbel-Softmax forward mode + CLIP. +# +# Ported from ft_scripts/sid_rqvae_clip_8192.feat_abstract.config to the +# feat/sid_abstract schema (repeated uint32 / float as one value per line), +# downsized for a single-box CPU run, and switched to Gumbel-Softmax: +# forward_mode: "ste" -> "gumbel_softmax" +# Notes on Gumbel: +# * kmeans_init stays false: Gumbel trains the codebook by gradient, so it +# needs no FAISS warm-start (and thus no faiss at init time). +# * Sinkhorn is auto-disabled under Gumbel by SidRqvae; no sinkhorn_config +# block is needed. +# Generate the matching mock parquet first: +# python examples/gen_sid_rqvae_mock_data.py --out_dir ./tmp/sid_rqvae_mock +model_dir: "experiments/sid_rqvae_gumbel_clip_local" + +train_config { + sparse_optimizer { + adam_optimizer { lr: 0.002 beta1: 0.9 beta2: 0.999 weight_decay: 0.0001 } + constant_learning_rate {} + } + dense_optimizer { + adamw_optimizer { lr: 0.002 beta1: 0.9 beta2: 0.999 weight_decay: 0.0001 } + constant_learning_rate {} + } + num_epochs: 2 + save_checkpoints_steps: 100 + log_step_count_steps: 5 + is_profiling: false +} + +eval_config {} + +data_config { + batch_size: 256 + dataset_type: ParquetDataset + fg_mode: FG_DAG + num_workers: 2 + drop_remainder: true +} + +feature_configs { + raw_feature { feature_name: "item1_emb" expression: "item:item1_embedding" value_dim: 512 } +} +feature_configs { + raw_feature { feature_name: "item2_emb" expression: "item:item2_embedding" value_dim: 512 } +} +feature_configs { + raw_feature { feature_name: "is_clip_pair" expression: "item:is_contrastive" value_dim: 1 } +} + +model_config { + feature_groups { + group_name: "deep" + feature_names: "item1_emb" + feature_names: "item2_emb" + feature_names: "is_clip_pair" + group_type: DEEP + } + sid_rqvae { + input_dim: 512 + embed_dim: 64 + hidden_dims: 256 + hidden_dims: 256 + codebook: 256 + codebook: 256 + codebook: 256 + forward_mode: "gumbel_softmax" + loss_type: "mse" + kmeans_init: false + latent_weight: 0.5 + latent_weight: 0.5 + embedding_feature_name: "item1_emb" + clip_config { + clip_feature_name: "item2_emb" + is_clip_pair_feature_name: "is_clip_pair" + } + } +} diff --git a/tzrec/modules/sid/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py index 346c43c41..d37ef7614 100644 --- a/tzrec/modules/sid/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -212,9 +212,7 @@ def test_is_subclass(self) -> None: def test_non_uniform_codebook_supported(self) -> None: rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) self.assertEqual(rkq.n_embed_list, [8, 4, 16]) - self.assertEqual( - [layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16] - ) + self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) def test_forward_returns_zeros_before_fit(self) -> None: rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) From 9c2e872df863bd93e3e6a586ac5bb089dc5d2433 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 09:00:16 +0000 Subject: [PATCH 092/129] [style] SID: simplify verbose comments/docstrings in RQ-VAE stack MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trim multi-sentence docstring prose and redundant inline comments across vector_quantize, residual_vector_quantizer, sid_rqvae, clip_loss, and the shared faiss_kmeans_fit. Comment/docstring-only — no logic, type-hint, or behavior change; ruff clean, SID unit tests green. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/clip_loss.py | 23 ++++---- tzrec/models/sid_rqvae.py | 7 +-- tzrec/modules/sid/kmeans_quantize.py | 17 ++---- .../modules/sid/residual_vector_quantizer.py | 50 ++++++----------- tzrec/modules/sid/vector_quantize.py | 53 ++++++------------- 5 files changed, 49 insertions(+), 101 deletions(-) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py index abc569ed1..5f4a0b56a 100644 --- a/tzrec/loss/clip_loss.py +++ b/tzrec/loss/clip_loss.py @@ -23,11 +23,9 @@ class MaskedCLIPLoss(_Loss): """Masked CLIP loss for mixed recon+clip batches. - In a mixed batch, recon rows (clip_mask=False) should not - contribute to CLIP loss, and recon columns should not serve as - negatives. This module applies row and column masks to achieve - selective contrastive learning without data-dependent branching, - ensuring ``torch.compile`` compatibility. + In a mixed batch, recon rows (clip_mask=False) must not contribute to the + CLIP loss, and recon columns must not serve as negatives. Row/column masks + achieve this without data-dependent branching (``torch.compile``-friendly). Input dict keys: 'image_embed': (B, D) quantized output of first feature @@ -56,11 +54,9 @@ def __init__(self) -> None: def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]: """All-gather tensors across workers with gradient support. - In single-process mode, returns the input tensors unchanged. In - multi-process mode, uses ``torch.distributed.nn.functional - .all_gather`` — the built-in differentiable collective (its backward - sum-reduces the per-rank grads and returns this rank's slice), so no - custom ``autograd.Function`` is needed. + Single-process: returns the inputs unchanged. Multi-process: uses the + built-in differentiable ``torch.distributed.nn.functional.all_gather``, + so no custom ``autograd.Function`` is needed. Args: tensors (List[Tensor]): list of tensors to gather. @@ -160,10 +156,9 @@ def forward( logits_img_cl = logit_scale_cl * image_embed @ image_embed_all_ori.t() logits_txt_cl = logit_scale_cl * text_embed @ text_embed_all_ori.t() - # Mask recon columns out of the negatives. Fill with the dtype's most - # negative finite value: provably below any real logit (so it masks like - # -inf regardless of logit_scale), but finite so an all-recon row gives - # a finite CE/grad instead of 0*NaN. + # Mask recon columns out of the negatives with the dtype's most negative + # finite value: below any real logit (masks like -inf), but finite so an + # all-recon row gives a finite CE/grad instead of 0*NaN. clip_mask_all = self._gather_bool_mask(clip_mask) col_mask = (~clip_mask_all).unsqueeze(0) # (1, B_global) neg_fill = torch.finfo(logits_img_self.dtype).min diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index d888c3ede..2caf1efa7 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -321,11 +321,8 @@ def init_metric(self) -> None: """Initialize metric modules (shared eval metrics + train-path mse).""" super().init_metric() - # Loss values are already logged by the framework via loss(); only - # quantization quality needs the train-path metric. unique_sid_ratio - # is intentionally eval-only: torch.unique(codes, dim=0).shape[0] - # forces a GPU->host sync every step, and codebook coverage is a - # diagnostic, not a training signal. + # Only the train-path reconstruction needs a metric here; unique_sid_ratio + # is eval-only (its torch.unique forces a per-step GPU->host sync). self._train_metric_modules["mse"] = torchmetrics.MeanSquaredError() def update_train_metric( diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index a4e12554a..36a95e435 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -40,18 +40,11 @@ def faiss_kmeans_fit( ) -> Any: """Train one ``faiss.Kmeans(dim, n_clusters)`` on ``x`` and return it. - The shared one-layer FAISS fit behind both SID residual-K-Means loops — the - RQ-VAE warm-start - (:func:`~tzrec.modules.sid.residual_vector_quantizer.faiss_residual_kmeans`) - and the offline RQ-K-Means - (:meth:`~tzrec.modules.sid.residual_kmeans_quantizer.ResidualKMeansQuantizer.train_offline`). - The caller reads ``km.centroids`` and runs assignment via - ``km.index.search``, keeping its own residual / chunking / device logic. - - Strips a stale ``gpu`` kwarg (a faiss-gpu build must not target an absent - GPU) and guards ``N >= n_clusters`` with a clear message before faiss's - opaque C++ throw. ``x`` may be a numpy array or a torch tensor - (``faiss.contrib.torch_utils``). + The shared one-layer FAISS fit behind both SID residual-K-Means loops (the + RQ-VAE warm-start and the offline RQ-K-Means); the caller reads + ``km.centroids`` and assigns via ``km.index.search``. Strips a stale ``gpu`` + kwarg and guards ``N >= n_clusters`` before faiss's opaque C++ throw. ``x`` + may be a numpy array or a torch tensor. Args: x: data points, shape (N, dim) — numpy array or torch tensor. diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 227de4267..d3c58c285 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -36,14 +36,11 @@ def faiss_residual_kmeans( ) -> List[torch.Tensor]: """Residual K-Means warm-start via FAISS, one pass per layer. - Clusters ``samples`` with FAISS K-Means, subtracts each point's assigned - centroid, and repeats on the residual for every layer. Used by - :meth:`ResidualVectorQuantizer.init_embed_` to seed the RQ-VAE codebook - from the first training batch — the same FAISS backend the offline - RQ-KMeans model uses, instead of a separate torch-native Lloyd's loop. - - CPU+GPU: the fit is always CPU (host fp32 numpy copy); centroids are - returned on ``samples.device`` — no faiss-gpu build needed. + Clusters ``samples``, subtracts each point's assigned centroid, and repeats + on the residual per layer. Seeds the RQ-VAE codebook (via + :meth:`ResidualVectorQuantizer.init_embed_`) from the first training batch. + The fit is always CPU (host fp32 numpy copy); centroids return on + ``samples.device`` — no faiss-gpu build needed. Args: samples (Tensor): data points, shape (N, D). @@ -217,13 +214,10 @@ def __init__( def init_embed_(self, data: torch.Tensor) -> None: """Initialize codebook weights via FAISS residual K-Means. - Only executed once when kmeans_init=True and not yet initialized. - Uses the first batch of training data as the initialization pool. - - Under DDP the codebook is fit on rank 0 only and broadcast, so every - rank starts from the SAME codebook. (Averaging per-rank centroids — - the previous behavior — mixes permutation-misaligned clusters across - ranks and yields a near-random warm start.) + Runs once (kmeans_init=True, not yet initialized), seeding from the first + training batch. Under DDP the fit happens on rank 0 and is broadcast, so + every rank starts from the same codebook (averaging per-rank centroids + would mix permutation-misaligned clusters into a near-random start). Args: data (Tensor): input data, shape (B, D). @@ -364,8 +358,7 @@ def _quantize_layer( layer = self.layers[layer_idx] out = layer.quantize(residual, temperature) if self._train_gumbel(): - # Gumbel: soft embedding carries grad to encoder + codebook. - return out.ids, out.embeddings + return out.ids, out.embeddings # soft embedding carries grad # STE / eval: raw codebook vector; STE applied on the aggregate in forward. return out.ids, layer.lookup(out.ids) @@ -376,16 +369,10 @@ def forward( ) -> ResidualQuantizerOutput: """Forward the multi-layer residual quantization. - Encoder gradient by ``forward_mode``: - - - STE: walk the DETACHED input, re-attach encoder grad via the aggregate - STE (step 4); the codebook trains via the commitment loss. - - Gumbel: the soft assignment is differentiable, so walk the LIVE input - (grad reaches encoder + codebook through the soft embeddings) and skip - the aggregate STE. - - Steps: (1) kmeans_init (first training forward); (2) residual walk; - (3) mean per-layer commitment loss; (4) aggregate STE (STE only). + Encoder gradient by ``forward_mode``: STE walks the DETACHED input and + re-attaches grad via the aggregate STE below (codebook trains via the + commitment loss); Gumbel's soft assignment is differentiable, so it walks + the LIVE input and skips the aggregate STE. Args: input (Tensor): input embeddings, shape (B, D). @@ -395,25 +382,22 @@ def forward( ResidualQuantizerOutput: (cluster_ids, quantized_embeddings, quantization_loss). """ - # Step 1: KMeans init (first training forward only) if self.training: - self.init_embed_(input) + self.init_embed_(input) # first training forward only train_gumbel = self._train_gumbel() - # Step 2: residual walk. Gumbel walks the live input; STE detaches and - # re-attaches grad in step 4. cumulative[i] = sum after layer i. + # cumulative[i] = sum after layer i. walk_input = input if train_gumbel else input.detach() cluster_ids, aggregated_quants, cumulative = self._residual_pass( walk_input, temperature ) - # Step 3: mean per-layer commitment loss commitment_loss = torch.mean( torch.stack([self._single_commitment_loss(input, c) for c in cumulative]) ) - # Step 4: aggregate STE (STE only; Gumbel already carries grad) + # Aggregate STE (STE only; Gumbel already carries grad). quants_trunc = aggregated_quants if self.training and not train_gumbel: if self.rotation_trick: diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index d0e30db5b..aeae0f109 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -35,9 +35,8 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso Returns: Tensor: squared distances, shape (N, K). - Branch-free (FX-traceable). Not ``no_grad`` (Gumbel needs grad here; the - STE/Sinkhorn callers wrap it in their own ``no_grad``); ``clamp`` is - out-of-place to stay autograd-safe. + Grad-enabled and branch-free (Gumbel needs grad; STE/Sinkhorn callers add + their own ``no_grad``). """ x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) @@ -126,18 +125,11 @@ def _sinkhorn( class VectorQuantize(QuantizeLayer): """Single codebook vector quantization layer (RQ-VAE backend). - The VQ :class:`~tzrec.modules.sid.quantize_layer.QuantizeLayer`: a - gradient-trained ``nn.Embedding`` codebook, the sibling of the K-Means - backend's :class:`~tzrec.modules.sid.kmeans_quantize.KMeansQuantizeLayer`. - Maps continuous input vectors to a codebook entry and returns the quantized - embeddings + codebook indices via :meth:`quantize`. The commitment loss is - computed at the residual-aggregator level by - :meth:`ResidualVectorQuantizer._single_commitment_loss` over the cumulative - quants (matching al_sid's ``RQBottleneck.compute_commitment_loss``); - this layer is intentionally loss-free. - - During training, Sinkhorn optimal-transport assignment is optionally - used to encourage uniform codebook utilization. + A gradient-trained ``nn.Embedding`` codebook (the VQ ``QuantizeLayer``), + sibling of the K-Means backend's ``KMeansQuantizeLayer``. Maps inputs to a + codebook entry via :meth:`quantize`. Loss-free: the commitment loss lives in + :meth:`ResidualVectorQuantizer._single_commitment_loss`. Sinkhorn + optimal-transport assignment optionally balances codebook usage in training. Args: embed_dim (int): dimension of each codebook embedding. @@ -164,12 +156,9 @@ def __init__( sinkhorn_epsilon: float = 10.0, ) -> None: super().__init__(n_embed=n_embed, embed_dim=embed_dim) - # Sinkhorn + Gumbel-Softmax pick the code by two different rules: - # `ids` come from the Sinkhorn balanced-assignment argmax, while the - # Gumbel branch builds `emb` from argmax(-distances + noise) (nearest - # code). The two indices generally disagree, so the saved SID would not - # match the codebook vector actually reconstructed/trained. STE avoids - # this by looking up embedding(ids) directly. Force a consistent combo. + # Sinkhorn drives `ids` (balanced assignment), Gumbel drives `emb` + # (nearest code); combining them makes the saved id and embedding + # diverge, so reject the combo (see the assert message). _is_gumbel = forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX assert not (use_sinkhorn and _is_gumbel), ( "use_sinkhorn=True is incompatible with forward_mode=GUMBEL_SOFTMAX: " @@ -255,16 +244,9 @@ def _find_nearest_embedding( def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: """Assign ``x`` to the codebook (the :class:`QuantizeLayer` interface). - Training flow: - 1. compute distances (L2 or cosine) - 2. if use_sinkhorn: z-score normalize + Sinkhorn -> argmax - else: argmin - 3. compute differentiable embedding (STE or Gumbel-Softmax) - Commitment loss is computed by the caller - (:meth:`ResidualVectorQuantizer._single_commitment_loss`). Device follows - ``x`` (and the codebook, which moves with the module), so this runs on - CPU or GPU unchanged. + (:meth:`ResidualVectorQuantizer._single_commitment_loss`); device follows + ``x``, so this runs on CPU or GPU unchanged. Args: x (Tensor): input vectors, shape (B, D). @@ -273,9 +255,8 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: Returns: QuantizeOutput: named tuple of (embeddings, ids). """ - # Gumbel: grad-enabled distances (so the encoder gets gradient); the - # hard sample drives both emb and ids, so the saved code matches the - # vector used. Sinkhorn is off here (ResidualVectorQuantizer.__init__). + # Gumbel: grad-enabled distances feed the encoder; the hard sample drives + # both emb and ids, so the saved code matches the vector used. if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: logits = -self._compute_distances(x) # (B, n_embed), differentiable weights = _gumbel_softmax_sample(logits, temperature=temperature, hard=True) @@ -283,12 +264,10 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: ids = weights.argmax(dim=-1) return QuantizeOutput(embeddings=emb, ids=ids) - # STE / eval: nearest-neighbour assignment under no_grad. (Gumbel - # early-returned above; STE is the only remaining training mode.) + # STE / eval: nearest-neighbour assignment under no_grad. ids, _ = self._find_nearest_embedding(x) if self.training: - # Straight-Through Estimator: gradient passes through. - quantized = self.embedding(ids) + quantized = self.embedding(ids) # straight-through: grad passes to x emb = x + (quantized - x).detach() else: emb = self.embedding(ids) From 81c6bb76aebcd37520d2e2e8f8b39185f7c82981 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 10:02:06 +0000 Subject: [PATCH 093/129] [chore] SID: remove stray gumbel example config + mock-data generator These were unrelated local smoke-test scaffolding that got swept into the PR when they were left staged across later commits. They are not part of the SidRqvae feature; drop them from the branch. Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/gen_sid_rqvae_mock_data.py | 72 ------------------- examples/sid_rqvae_gumbel_clip_local.config | 78 --------------------- 2 files changed, 150 deletions(-) delete mode 100644 examples/gen_sid_rqvae_mock_data.py delete mode 100644 examples/sid_rqvae_gumbel_clip_local.config diff --git a/examples/gen_sid_rqvae_mock_data.py b/examples/gen_sid_rqvae_mock_data.py deleted file mode 100644 index 8f753c9d0..000000000 --- a/examples/gen_sid_rqvae_mock_data.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Write a mock embedding parquet for the SidRqvae Gumbel+CLIP smoke config. - -Columns match ``examples/sid_rqvae_gumbel_clip_local.config``: - item1_embedding (list[dim]) -- the SID input embedding - item2_embedding (list[dim]) -- the CLIP-paired embedding - is_contrastive (float32 scalar) -- 1.0 = CLIP pair, 0.0 = recon-only - -Usage: - python examples/gen_sid_rqvae_mock_data.py --out_dir ./tmp/sid_rqvae_mock \ - --num_rows 4096 --dim 512 --clip_ratio 0.5 -""" - -import argparse -import os - -import numpy as np -import pyarrow as pa -import pyarrow.parquet as pq - - -def main() -> None: - """Generate the mock parquet shard.""" - parser = argparse.ArgumentParser() - parser.add_argument("--out_dir", default="./tmp/sid_rqvae_mock") - parser.add_argument("--num_rows", type=int, default=4096) - parser.add_argument("--dim", type=int, default=512) - parser.add_argument( - "--clip_ratio", - type=float, - default=0.5, - help="fraction of rows flagged as CLIP pairs (is_contrastive=1)", - ) - parser.add_argument("--seed", type=int, default=0) - args = parser.parse_args() - - rng = np.random.default_rng(args.seed) - item1 = rng.standard_normal((args.num_rows, args.dim)).astype(np.float32) - # item2 is a noisy view of item1 so the contrastive pairs are learnable. - item2 = (item1 + 0.1 * rng.standard_normal(item1.shape)).astype(np.float32) - is_clip = (rng.random(args.num_rows) < args.clip_ratio).astype(np.float32) - - os.makedirs(args.out_dir, exist_ok=True) - out_path = os.path.join(args.out_dir, "part-0.parquet") - pq.write_table( - pa.table( - { - "item1_embedding": pa.array(list(item1)), - "item2_embedding": pa.array(list(item2)), - "is_contrastive": pa.array(is_clip), - } - ), - out_path, - ) - print( - f"wrote {args.num_rows} rows (dim={args.dim}, " - f"clip_pairs={int(is_clip.sum())}) -> {out_path}" - ) - - -if __name__ == "__main__": - main() diff --git a/examples/sid_rqvae_gumbel_clip_local.config b/examples/sid_rqvae_gumbel_clip_local.config deleted file mode 100644 index ed24b9ce3..000000000 --- a/examples/sid_rqvae_gumbel_clip_local.config +++ /dev/null @@ -1,78 +0,0 @@ -# SidRqvae end-to-end smoke config — Gumbel-Softmax forward mode + CLIP. -# -# Ported from ft_scripts/sid_rqvae_clip_8192.feat_abstract.config to the -# feat/sid_abstract schema (repeated uint32 / float as one value per line), -# downsized for a single-box CPU run, and switched to Gumbel-Softmax: -# forward_mode: "ste" -> "gumbel_softmax" -# Notes on Gumbel: -# * kmeans_init stays false: Gumbel trains the codebook by gradient, so it -# needs no FAISS warm-start (and thus no faiss at init time). -# * Sinkhorn is auto-disabled under Gumbel by SidRqvae; no sinkhorn_config -# block is needed. -# Generate the matching mock parquet first: -# python examples/gen_sid_rqvae_mock_data.py --out_dir ./tmp/sid_rqvae_mock -model_dir: "experiments/sid_rqvae_gumbel_clip_local" - -train_config { - sparse_optimizer { - adam_optimizer { lr: 0.002 beta1: 0.9 beta2: 0.999 weight_decay: 0.0001 } - constant_learning_rate {} - } - dense_optimizer { - adamw_optimizer { lr: 0.002 beta1: 0.9 beta2: 0.999 weight_decay: 0.0001 } - constant_learning_rate {} - } - num_epochs: 2 - save_checkpoints_steps: 100 - log_step_count_steps: 5 - is_profiling: false -} - -eval_config {} - -data_config { - batch_size: 256 - dataset_type: ParquetDataset - fg_mode: FG_DAG - num_workers: 2 - drop_remainder: true -} - -feature_configs { - raw_feature { feature_name: "item1_emb" expression: "item:item1_embedding" value_dim: 512 } -} -feature_configs { - raw_feature { feature_name: "item2_emb" expression: "item:item2_embedding" value_dim: 512 } -} -feature_configs { - raw_feature { feature_name: "is_clip_pair" expression: "item:is_contrastive" value_dim: 1 } -} - -model_config { - feature_groups { - group_name: "deep" - feature_names: "item1_emb" - feature_names: "item2_emb" - feature_names: "is_clip_pair" - group_type: DEEP - } - sid_rqvae { - input_dim: 512 - embed_dim: 64 - hidden_dims: 256 - hidden_dims: 256 - codebook: 256 - codebook: 256 - codebook: 256 - forward_mode: "gumbel_softmax" - loss_type: "mse" - kmeans_init: false - latent_weight: 0.5 - latent_weight: 0.5 - embedding_feature_name: "item1_emb" - clip_config { - clip_feature_name: "item2_emb" - is_clip_pair_feature_name: "is_clip_pair" - } - } -} From 8244af2a9161ecc89747903bbbcd9a4b9f84db14 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 12:04:15 +0000 Subject: [PATCH 094/129] [refactor] SID: post-review cleanup + RQ-VAE robustness fixes Simplification: - inline the one-line _gumbel_softmax_sample wrapper into F.gumbel_softmax - _find_nearest_embedding returns ids only (its distances return was always discarded); drop the now-unused Tuple import Robustness (from the full PR code-review): - RVQ.init_embed_: broadcast rank-0's "enough rows" verdict before the rank-divergent FAISS fit, so all ranks raise together on a too-small first batch instead of deadlocking on the centroid broadcast - KMeansQuantizeLayer.quantize: cast x to the centroid dtype before cdist (consistent with load_centroids_) - SidRqvae.__init__: fail-fast on embed_dim < 1 / hidden_dims entry < 1 (parity with BaseSidModel's codebook/input_dim checks) Docs: - faiss_kmeans_fit keeps the gpu-kwarg strip (faiss honors gpu, so it is the CPU-only guard, not dead code); clarified why - sid_model.proto: document kmeans_init as a best-effort single-batch warm-start needing batch_size >= max(codebook) Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 6 +++ tzrec/modules/sid/kmeans_quantize.py | 13 +++++-- .../modules/sid/residual_kmeans_quantizer.py | 9 +++-- .../modules/sid/residual_vector_quantizer.py | 14 +++++++ tzrec/modules/sid/vector_quantize.py | 38 ++++--------------- tzrec/protos/models/sid_model.proto | 15 ++++++-- 6 files changed, 53 insertions(+), 42 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 2caf1efa7..79513f8d0 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -88,9 +88,15 @@ def __init__( ) embed_dim = cfg.embed_dim + # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero + # dim only errors opaquely deep in nn.Linear/Embedding otherwise. + if embed_dim < 1: + raise ValueError(f"embed_dim must be >= 1, got {embed_dim}") hidden_dims = ( list(cfg.hidden_dims) if cfg.hidden_dims else [self._input_dim // 2] ) + if any(h < 1 for h in hidden_dims): + raise ValueError(f"every hidden_dims entry must be >= 1, got {hidden_dims}") # Empty -> default (1.0, 0.5); the quantizer validates the arity. latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index 36a95e435..2641e6963 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -42,9 +42,10 @@ def faiss_kmeans_fit( The shared one-layer FAISS fit behind both SID residual-K-Means loops (the RQ-VAE warm-start and the offline RQ-K-Means); the caller reads - ``km.centroids`` and assigns via ``km.index.search``. Strips a stale ``gpu`` - kwarg and guards ``N >= n_clusters`` before faiss's opaque C++ throw. ``x`` - may be a numpy array or a torch tensor. + ``km.centroids`` and assigns via ``km.index.search``. Strips a ``gpu`` kwarg + (faiss honors it and would move the fit to GPU, breaking the CPU-only + contract) and guards ``N >= n_clusters`` before faiss's opaque C++ throw. + ``x`` may be a numpy array or a torch tensor. Args: x: data points, shape (N, dim) — numpy array or torch tensor. @@ -67,6 +68,8 @@ def faiss_kmeans_fit( "`pip install faiss-cpu` or `pip install faiss-gpu`." ) from e + # Copy + drop any `gpu` key: faiss.Kmeans honors it and would move the fit + # to GPU, breaking the CPU-only contract (the caller's dict stays untouched). kwargs = dict(faiss_kmeans_kwargs or {}) kwargs.pop("gpu", None) n = int(x.shape[0]) @@ -271,7 +274,9 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: if not self.is_initialized: ids = torch.zeros(x.shape[0], dtype=torch.long, device=x.device) return QuantizeOutput(embeddings=torch.zeros_like(x), ids=ids) - ids = torch.cdist(x, self.centroids).argmin(dim=-1) + # Match x to the centroid dtype (as load_centroids_ does): cdist rejects + # mismatched dtypes, so a non-fp32 input would otherwise raise. + ids = torch.cdist(x.to(self.centroids.dtype), self.centroids).argmin(dim=-1) return QuantizeOutput(embeddings=self.centroids[ids], ids=ids) def get_codebook_embeddings(self) -> torch.Tensor: diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index 2ebee0ec2..bc7e64300 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -201,9 +201,7 @@ def train_offline( # breaks that invariant, so clone then. x0 = x.clone() if (verbose and self.normalize_residuals) else None - # CPU-only fit (SidRqkmeans refuses CUDA). The ``gpu`` kwarg is stripped - # inside faiss_kmeans_fit. - kwargs = dict(self.faiss_kmeans_kwargs) + # CPU-only fit (SidRqkmeans refuses CUDA). if verbose: logger.info( "[ResidualKMeansQuantizer] fitting %d-layer codebook on CPU " @@ -223,7 +221,10 @@ def train_offline( # Fresh Kmeans per layer so each can use its own K (non-uniform # codebooks). km = faiss_kmeans_fit( - x, self.embed_dim, self.n_embed_list[layer_idx], kwargs + x, + self.embed_dim, + self.n_embed_list[layer_idx], + self.faiss_kmeans_kwargs, ) centroids = torch.as_tensor(km.centroids, dtype=torch.float32) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index d3c58c285..cc9c23c87 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -226,6 +226,20 @@ def init_embed_(self, data: torch.Tensor) -> None: return is_ddp = dist.is_initialized() and dist.get_world_size() > 1 + # The fit runs on rank 0 only, then broadcasts. faiss needs N >= max(K), + # so a too-small rank-0 first batch would raise on rank 0 while the other + # ranks block forever on the centroid broadcast. Broadcast rank 0's verdict + # first so every rank aborts together with a clear error instead. + max_k = max(self.n_embed_list) + enough = torch.tensor([1 if data.shape[0] >= max_k else 0], device=data.device) + if is_ddp: + dist.broadcast(enough, src=0) + if enough.item() == 0: + raise RuntimeError( + f"kmeans_init: rank-0 first training batch has fewer rows than the " + f"largest codebook ({max_k}); raise batch_size or disable kmeans_init." + ) + if (not is_ddp) or dist.get_rank() == 0: centers = faiss_residual_kmeans( data, diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index aeae0f109..3ccdb042f 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -11,8 +11,6 @@ """Single codebook vector quantization with Sinkhorn uniform assignment.""" -from typing import Tuple - import torch import torch.distributed as dist from torch import nn @@ -43,24 +41,6 @@ def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso return (x_sq + y_sq - 2.0 * x @ y.t()).clamp(min=0.0) -def _gumbel_softmax_sample( - logits: torch.Tensor, - temperature: float = 1.0, - hard: bool = True, -) -> torch.Tensor: - """Sample from the Gumbel-Softmax distribution. - - Args: - logits (Tensor): un-normalized log probabilities, shape (B, N). - temperature (float): temperature for Gumbel-Softmax. - hard (bool): if True, return one-hot with straight-through gradient. - - Returns: - Tensor: soft or hard sample, shape (B, N). - """ - return F.gumbel_softmax(logits, tau=temperature, hard=hard, dim=-1) - - @torch.no_grad() def _sinkhorn( cost: torch.Tensor, @@ -180,7 +160,8 @@ def _compute_distances(self, x: torch.Tensor) -> torch.Tensor: """Compute L2/cosine distances between inputs and codebook entries. Not ``no_grad``: Gumbel calls this directly for the encoder gradient; - the STE/Sinkhorn path wraps it via ``no_grad`` :meth:`_find_nearest_embedding`. + the STE/Sinkhorn path calls it inside ``no_grad`` in + :meth:`_find_nearest_embedding`. Args: x (Tensor): input vectors, shape (B, D). @@ -204,10 +185,8 @@ def _compute_distances(self, x: torch.Tensor) -> torch.Tensor: return distances @torch.no_grad() - def _find_nearest_embedding( - self, x: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Find nearest codebook entry for each input vector. + def _find_nearest_embedding(self, x: torch.Tensor) -> torch.Tensor: + """Find the nearest codebook id for each input vector. During training with use_sinkhorn=True, applies z-score normalization + non-negative shift before Sinkhorn assignment. @@ -217,8 +196,7 @@ def _find_nearest_embedding( x (Tensor): input vectors, shape (B, D). Returns: - ids (Tensor): codebook indices, shape (B,). - distances (Tensor): distance matrix, shape (B, n_embed). + Tensor: codebook indices, shape (B,). """ distances = self._compute_distances(x) # (B, n_embed) @@ -239,7 +217,7 @@ def _find_nearest_embedding( else: ids = distances.argmin(dim=-1) - return ids, distances + return ids def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: """Assign ``x`` to the codebook (the :class:`QuantizeLayer` interface). @@ -259,13 +237,13 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: # both emb and ids, so the saved code matches the vector used. if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: logits = -self._compute_distances(x) # (B, n_embed), differentiable - weights = _gumbel_softmax_sample(logits, temperature=temperature, hard=True) + weights = F.gumbel_softmax(logits, tau=temperature, hard=True, dim=-1) emb = weights @ self.embedding.weight ids = weights.argmax(dim=-1) return QuantizeOutput(embeddings=emb, ids=ids) # STE / eval: nearest-neighbour assignment under no_grad. - ids, _ = self._find_nearest_embedding(x) + ids = self._find_nearest_embedding(x) if self.training: quantized = self.embedding(ids) # straight-through: grad passes to x emb = x + (quantized - x).detach() diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index d160a57fc..102a16c84 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -67,10 +67,17 @@ message SidRqvae { repeated float latent_weight = 11; // STE rotation trick. optional bool rotation_trick = 12 [default = false]; - // KMeans codebook initialization on first training forward. Default false: - // the FAISS warm-start needs the first batch to have >= max(codebook) rows - // (faiss requires N >= K) and, under DDP, a rank-0 fit failure would hang - // the other ranks on the centroid broadcast — so it is opt-in. + // KMeans codebook initialization on first training forward. Default false. + // Best-effort warm-start only: it seeds the codebook from a SINGLE batch + // (the encoder is still near-random at step 1), and gradient training + + // commitment loss refine it afterward. Requirements/caveats: + // * batch_size >= max(codebook) (FAISS requires N >= K) — so it is + // unusable with large codebooks at typical batch sizes (e.g. an 8192 + // codebook needs a >= 8192-row first batch); under DDP a too-small + // rank-0 batch now raises on all ranks instead of hanging. + // * one batch is statistically thin (few points per centroid); for a + // data-rich fit use the SidRqkmeans model (reservoir-sampled) instead. + // Opt-in for these reasons. optional bool kmeans_init = 13 [default = false]; // === Optional sub-module configs === From cb39ca2cc0b939bdefb16aab4264665f4bc47281 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 12:27:47 +0000 Subject: [PATCH 095/129] [simplify] SID: drop dead is_distributed param from _sinkhorn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The sole caller passed is_distributed=dist.is_initialized(), and every guard inside was `is_distributed and dist.is_initialized()` — so the param was always identical to dist.is_initialized(). Replace the three guards with a plain `if dist.is_initialized():` and drop the parameter. Bit-identical behavior. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/vector_quantize.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 3ccdb042f..d998774d9 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -46,13 +46,13 @@ def _sinkhorn( cost: torch.Tensor, n_iters: int = 5, epsilon: float = 10.0, - is_distributed: bool = True, ) -> torch.Tensor: """Sinkhorn-Knopp algorithm for optimal-transport based uniform assignment. Transforms a distance matrix into a soft assignment matrix via exponential kernel and alternating row-column normalization, approximating a doubly - stochastic matrix to ensure uniform codebook utilization. + stochastic matrix to ensure uniform codebook utilization. Row sums are + all-reduced across ranks when a process group is initialized. Args: cost (Tensor): distance matrix, shape (B, K) where K is codebook size. @@ -61,8 +61,6 @@ def _sinkhorn( n_iters (int): number of Sinkhorn iterations. Default: 5. epsilon (float): sharpness parameter for exp(-cost * epsilon). Larger values produce sharper assignments. Default: 10.0. - is_distributed (bool): whether running in distributed mode. - If True, row sums are all_reduced across GPUs. Default: True. Returns: Tensor: assignment matrix, shape (B, K). @@ -72,7 +70,7 @@ def _sinkhorn( Q = torch.exp(-cost * epsilon).t() # Global batch size for distributed training - if is_distributed and dist.is_initialized(): + if dist.is_initialized(): B = Q.size(1) * dist.get_world_size() else: B = Q.size(1) @@ -80,7 +78,7 @@ def _sinkhorn( # Step 2: global normalization — make matrix sum to 1 sum_Q = torch.sum(Q) - if is_distributed and dist.is_initialized(): + if dist.is_initialized(): dist.all_reduce(sum_Q) Q /= sum_Q + 1e-8 @@ -88,7 +86,7 @@ def _sinkhorn( for _ in range(n_iters): # Row normalization: each prototype's total weight = 1/K sum_of_rows = torch.sum(Q, dim=1, keepdim=True) - if is_distributed and dist.is_initialized(): + if dist.is_initialized(): dist.all_reduce(sum_of_rows) Q /= sum_of_rows + 1e-8 Q /= K @@ -211,7 +209,6 @@ def _find_nearest_embedding(self, x: torch.Tensor) -> torch.Tensor: distances, n_iters=self.sinkhorn_iters, epsilon=self.sinkhorn_epsilon, - is_distributed=dist.is_initialized(), ) ids = Q.argmax(dim=-1) else: From 9cf7a629f262dc55077c826f0f690d3eca73a1d9 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 15 Jun 2026 13:14:15 +0000 Subject: [PATCH 096/129] =?UTF-8?q?[fix]=20SID:=20address=20PR=20review=20?= =?UTF-8?q?=E2=80=94=20logit=5Fscale=20clamp,=20trim=20CLIP=20outputs,=20d?= =?UTF-8?q?edup=20gather=20+=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bot-review fixes: - SidRqvae: clamp logit_scale to ln(100) before exp() — unbounded exp() -> +Inf -> a NaN grad that permanently corrupts the param (the loss-side nan_to_num only sanitizes the CE output, not the parameter). - MaskedCLIPLoss: drop the unused clip_acc / loss_self/ori/cl return keys and the eval-only accuracy block (no consumer); returns {clip_loss}. - VectorQuantize.quantize: add apply_ste flag; the RVQ caller passes False to take the raw codebook vector in a single gather instead of building + discarding an STE wrap and re-gathering. Bit-identical training. - sid_model.proto: fix the stale RQVAE reference in the latent_weight comment. Test coverage (gaps flagged in the review summary): - Sinkhorn functionally balances clustered points (vs argmin collapse). - rotation_trick=True Householder STE branch trains (grad reaches the input). - loss_type l1/cosine and commitment_loss cos branches run end-to-end. - CLIP masked-column negatives: perturbing recon rows leaves clip-row loss intact. - kmeans_init "rank-0 batch too small" abort raises a clear error. #2 (all_gather batch-alignment) intentionally NOT changed: batch/collective alignment is the framework's responsibility (rebalance / drop_redundant_bs_eq_one / TorchRec TrainPipelineSparseDist), not a model-level guard. ruff clean; SID suite green (vector_quantize 11, residual_vector_quantizer 7, clip_loss 7, sid_rqvae 16, residual_quantizer 22, kmeans_quantize 12, sid_rqkmeans 18). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/clip_loss.py | 29 +----- tzrec/loss/clip_loss_test.py | 39 ++++++++ tzrec/models/sid_rqvae.py | 12 ++- tzrec/models/sid_rqvae_test.py | 95 +++++++++++++++++++ .../modules/sid/residual_vector_quantizer.py | 15 +-- .../sid/residual_vector_quantizer_test.py | 52 ++++++++++ tzrec/modules/sid/vector_quantize.py | 24 +++-- tzrec/modules/sid/vector_quantize_test.py | 23 +++++ tzrec/protos/models/sid_model.proto | 2 +- 9 files changed, 244 insertions(+), 47 deletions(-) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py index 5f4a0b56a..5df65b880 100644 --- a/tzrec/loss/clip_loss.py +++ b/tzrec/loss/clip_loss.py @@ -37,11 +37,7 @@ class MaskedCLIPLoss(_Loss): 'logit_scale': scalar original feature contrast temperature Output dict keys: - 'clip_loss': scalar mean of three losses (self/ori/cl) - 'clip_acc': scalar contrast accuracy (%); 0 during training - 'loss_self': scalar quantized vs quantized - 'loss_ori': scalar quantized vs original - 'loss_cl': scalar quantized vs counterpart original + 'clip_loss': scalar mean of three contrastive losses (self/ori/cl) """ def __init__(self) -> None: @@ -188,25 +184,4 @@ def forward( clip_loss = (loss_self + loss_ori + loss_cl) / 3 - # Retrieval accuracy is diagnostic-only; skip the four argmax+eq+sum - # reductions during training (recover via the eval pass). - if self.training: - acc = torch.zeros((), device=clip_loss.device) - else: - with torch.no_grad(): - n_valid = clip_mask.float().sum().clamp(min=1) - correct = ( - (logits_img_self.argmax(-1).eq(safe_labels) & clip_mask).sum() - + (logits_txt_self.argmax(-1).eq(safe_labels) & clip_mask).sum() - + (logits_img_ori.argmax(-1).eq(safe_labels) & clip_mask).sum() - + (logits_txt_ori.argmax(-1).eq(safe_labels) & clip_mask).sum() - ) - acc = 100 * correct / (n_valid * 4) - - return { - "clip_loss": clip_loss, - "clip_acc": acc, - "loss_self": loss_self, - "loss_ori": loss_ori, - "loss_cl": loss_cl, - } + return {"clip_loss": clip_loss} diff --git a/tzrec/loss/clip_loss_test.py b/tzrec/loss/clip_loss_test.py index cbd30691f..8ba9a2c6b 100644 --- a/tzrec/loss/clip_loss_test.py +++ b/tzrec/loss/clip_loss_test.py @@ -80,6 +80,45 @@ def test_backward_flows_to_embeddings(self) -> None: self.assertIsNotNone(feats["image_embed"].grad) self.assertTrue(torch.isfinite(feats["image_embed"].grad).all()) + def test_recon_columns_excluded_from_negatives(self) -> None: + """A recon row's embedding must not affect a clip row's loss. + + Recon rows are dropped as queries (row mask) AND their columns are + masked out of the negatives (col_mask). Perturbing the recon rows of + EVERY column operand — ``text_embed`` (the self group) and both + ``*_ori`` operands (the ori/cl groups) — must leave the clip rows' loss + unchanged; a dropped or inverted ``col_mask`` on any group would fail. + Distinct ``image_embed_ori`` / ``text_embed_ori`` so the ori/cl masking + is actually exercised (not hidden by a shared tensor). + """ + torch.manual_seed(0) + B, D = 4, 8 + img = torch.randn(B, D) + scale = torch.tensor(10.0) + mask = torch.tensor([True, True, False, False]) # rows 2,3 are recon + + def feats(txt: torch.Tensor, txt_ori: torch.Tensor, img_ori: torch.Tensor): + return { + "image_embed": img, + "text_embed": txt, + "image_embed_ori": img_ori, + "text_embed_ori": txt_ori, + "logit_scale_self": scale, + "logit_scale_cl": scale, + "logit_scale": scale, + } + + txt, txt_ori, img_ori = (torch.randn(B, D) for _ in range(3)) + loss_fn = MaskedCLIPLoss() + loss_fn.eval() + base = loss_fn(feats(txt, txt_ori, img_ori), mask)["clip_loss"] + # Perturb ONLY the recon rows of every column operand that feeds negatives. + txt2, txt_ori2, img_ori2 = txt.clone(), txt_ori.clone(), img_ori.clone() + for t in (txt2, txt_ori2, img_ori2): + t[2:] = torch.randn(2, D) + after = loss_fn(feats(txt2, txt_ori2, img_ori2), mask)["clip_loss"] + torch.testing.assert_close(base, after) + def test_mask_holds_under_large_scale(self) -> None: # The column fill is finfo.min (below any real logit) rather than a # hardcoded -1e4, so masking holds even when logit_scale is large and diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 79513f8d0..3491dbd07 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -35,6 +35,10 @@ from tzrec.protos.model_pb2 import ModelConfig from tzrec.utils.logging_util import logger +# Cap the CLIP temperatures before ``exp`` (reference CLIP clamps to ln(100)): +# an unbounded ``logit_scale`` overflows to +Inf -> NaN grad -> corrupt param. +_LOGIT_SCALE_MAX = float(np.log(100)) + class SidRqvae(BaseSidModel): """SID generation model using RQ-VAE (Encoder + VQ + Decoder). @@ -229,9 +233,11 @@ def _forward_mixed( "text_embed": x_hat2, "image_embed_ori": fea1, "text_embed_ori": fea2, - "logit_scale_self": self._logit_scale_self.exp(), - "logit_scale_cl": self._logit_scale_cl.exp(), - "logit_scale": self._logit_scale.exp(), + "logit_scale_self": self._logit_scale_self.clamp( + max=_LOGIT_SCALE_MAX + ).exp(), + "logit_scale_cl": self._logit_scale_cl.clamp(max=_LOGIT_SCALE_MAX).exp(), + "logit_scale": self._logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp(), } clip_result = self._masked_clip_loss_fn(features, clip_mask) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index d3887be6c..5d62aeaf9 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -401,6 +401,101 @@ def test_commitment_loss_invalid_raises(self) -> None: use_sinkhorn=False, ) + def test_loss_type_l1_and_cosine(self) -> None: + """loss_type 'l1' and 'cosine' recon branches run end-to-end. + + Only 'mse' was previously exercised; a typo'd branch would have been + silent. + """ + B, input_dim = 4, 32 + for loss_type in ("l1", "cosine"): + cfg = sid_model_pb2.SidRqvae( + input_dim=input_dim, + embed_dim=8, + codebook=[16, 16], + forward_mode="ste", + loss_type=loss_type, + kmeans_init=False, + embedding_feature_name="item_emb", + ) + model = SidRqvae( + model_config=model_pb2.ModelConfig(sid_rqvae=cfg), + features=[], + labels=[], + ) + init_parameters(model, device=torch.device("cpu")) + model.train() + model.init_loss() + preds = model.predict(_make_batch(B, input_dim)) + recon = preds["reconstruction_loss"] + self.assertTrue(torch.isfinite(recon), f"{loss_type} recon not finite") + recon.backward() # grad must flow through the decoder + + def test_commitment_loss_cos_branch(self) -> None: + """Verify the commitment_loss='cos' branch runs end-to-end.""" + from tzrec.modules.sid.residual_vector_quantizer import ( + ResidualVectorQuantizer, + ) + + torch.manual_seed(0) + rq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=4, + forward_mode="ste", + commitment_loss="cos", + kmeans_init=False, + use_sinkhorn=False, + ) + for layer in rq.layers: + torch.nn.init.normal_(layer.embedding.weight, std=0.1) + x = torch.randn(4, 8, requires_grad=True) + out = rq(x) + self.assertTrue(torch.isfinite(out.quantization_loss)) + out.quantization_loss.backward() + self.assertIsNotNone(x.grad) + + def test_logit_scale_clamped_prevents_overflow(self) -> None: + """A raw logit_scale far above ln(100) must not overflow. + + The clamp caps ``exp()`` so the CLIP loss and the parameter gradient + stay finite; without it, ``exp(large)`` -> +Inf -> a NaN gradient that + permanently corrupts the parameter. + """ + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim, use_clip=True) + model.train() + model.init_loss() + with torch.no_grad(): + model._logit_scale_self.fill_(100.0) + model._logit_scale_cl.fill_(100.0) + model._logit_scale.fill_(100.0) + + batch = Batch( + dense_features={ + BASE_DATA_GROUP: KeyedTensor.from_tensor_list( + keys=["item_emb", "image_emb", "is_clip_pair"], + tensors=[ + torch.randn(B, input_dim), + torch.randn(B, input_dim), + torch.ones(B, 1), + ], + ) + }, + sparse_features={}, + labels={}, + ) + losses = model.loss(model.predict(batch), batch) + self.assertTrue(torch.isfinite(losses["clip_loss"])) + sum(losses.values()).backward() + for p in ( + model._logit_scale_self, + model._logit_scale_cl, + model._logit_scale, + ): + self.assertIsNotNone(p.grad) + self.assertTrue(torch.isfinite(p.grad).all()) + if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index cc9c23c87..754fee9bf 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -367,14 +367,15 @@ def _quantize_layer( Returns: ids (Tensor): per-layer cluster ids, shape (B,). - raw_emb (Tensor): raw codebook vectors (with grad), shape (B, D). + emb (Tensor): the raw codebook vector (STE/eval) or the soft + embedding (Gumbel), with grad, shape (B, D). """ - layer = self.layers[layer_idx] - out = layer.quantize(residual, temperature) - if self._train_gumbel(): - return out.ids, out.embeddings # soft embedding carries grad - # STE / eval: raw codebook vector; STE applied on the aggregate in forward. - return out.ids, layer.lookup(out.ids) + # apply_ste=False: Gumbel ignores it (returns the soft embedding that + # carries grad); STE/eval get the raw codebook vector in one gather (STE + # is applied on the aggregate in :meth:`forward`), avoiding the discarded + # per-layer straight-through wrap + a second codebook gather. + out = self.layers[layer_idx].quantize(residual, temperature, apply_ste=False) + return out.ids, out.embeddings def forward( self, diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index 1c71936bf..d196f4e2e 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -107,5 +107,57 @@ def test_raises_on_too_few_points(self) -> None: faiss_residual_kmeans(torch.randn(4, 6), [8]) +class ResidualVQBranchTest(unittest.TestCase): + """Coverage for the rotation-trick STE branch and the kmeans-init guard.""" + + def test_rotation_trick_rotates_gradient(self) -> None: + # The rotation trick keeps the STE forward value but ROTATES the input + # gradient. Plain STE also yields a finite non-zero grad, so a regression + # that reverted the Householder branch to ordinary STE would pass a smoke + # test. Pin the distinguishing property: same forward output, different + # input gradient. + def run(rotation_trick: bool): + torch.manual_seed(0) # identical codebook init + rvq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=16, + forward_mode="ste", + rotation_trick=rotation_trick, + use_sinkhorn=False, + kmeans_init=False, + ) + rvq.train() + torch.manual_seed(1) # identical input + z = torch.randn(16, 8, requires_grad=True) + out = rvq(z).quantized_embeddings + out.sum().backward() + return out.detach(), z.grad + + out_rot, grad_rot = run(True) + out_ste, grad_ste = run(False) + self.assertTrue(torch.isfinite(grad_rot).all()) + self.assertGreater(grad_rot.abs().sum().item(), 0.0) + # Forward value is identical (the trick only changes the backward). + torch.testing.assert_close(out_rot, out_ste) + # ...but it genuinely rotates the gradient, unlike plain STE. + self.assertFalse(torch.allclose(grad_rot, grad_ste)) + + def test_kmeans_init_too_small_batch_raises(self) -> None: + # kmeans_init needs N >= max(codebook). A too-small first batch must + # raise a clear error (broadcast so all ranks abort together under DDP), + # not hang the non-rank-0 ranks on the centroid broadcast. + rvq = ResidualVectorQuantizer( + embed_dim=4, + n_layers=2, + n_embed=8, + kmeans_init=True, + use_sinkhorn=False, + ) + rvq.train() + with self.assertRaisesRegex(RuntimeError, "fewer rows than the largest"): + rvq(torch.randn(4, 4)) # 4 < max(codebook)=8 + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index d998774d9..20c893490 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -216,7 +216,9 @@ def _find_nearest_embedding(self, x: torch.Tensor) -> torch.Tensor: return ids - def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + def quantize( + self, x: torch.Tensor, temperature: float = 1.0, apply_ste: bool = True + ) -> QuantizeOutput: """Assign ``x`` to the codebook (the :class:`QuantizeLayer` interface). Commitment loss is computed by the caller @@ -226,6 +228,11 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: Args: x (Tensor): input vectors, shape (B, D). temperature (float): temperature for Gumbel-Softmax. + apply_ste (bool): in STE training, wrap the embedding with the + straight-through estimator (``x + (q - x).detach()``). Set False + when the caller re-applies STE on the aggregate + (:class:`ResidualVectorQuantizer`): the raw codebook vector is + returned and the otherwise-discarded wrap + re-gather is avoided. Returns: QuantizeOutput: named tuple of (embeddings, ids). @@ -239,15 +246,14 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: ids = weights.argmax(dim=-1) return QuantizeOutput(embeddings=emb, ids=ids) - # STE / eval: nearest-neighbour assignment under no_grad. + # STE / eval: nearest-neighbour assignment under no_grad, one codebook + # gather. STE wrap only when the caller wants it (standalone use); the + # RVQ caller passes apply_ste=False and re-applies STE on the aggregate. ids = self._find_nearest_embedding(x) - if self.training: - quantized = self.embedding(ids) # straight-through: grad passes to x - emb = x + (quantized - x).detach() - else: - emb = self.embedding(ids) - - return QuantizeOutput(embeddings=emb, ids=ids) + quantized = self.embedding(ids) + if self.training and apply_ste: + quantized = x + (quantized - x).detach() # grad passes to x + return QuantizeOutput(embeddings=quantized, ids=ids) def get_codebook_embeddings(self) -> torch.Tensor: """Return the codebook table, shape (n_embed, embed_dim).""" diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index b6a05f269..31cc803b7 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -62,6 +62,29 @@ def test_train_forward(self, _name, mode, distance_type, use_sinkhorn) -> None: self.assertTrue((out.ids >= 0).all() and (out.ids < 16).all()) self.assertTrue(torch.isfinite(out.embeddings).all()) + def test_sinkhorn_balances_assignment(self) -> None: + """Sinkhorn spreads clustered points across codes; argmin collapses them. + + Functional check (not just shape/finiteness): feed points clustered at + one anchor — argmin sends all to that code, while Sinkhorn's uniform + assignment must use more than one code. + """ + torch.manual_seed(0) + vq = VectorQuantize( + embed_dim=2, n_embed=4, use_sinkhorn=True, sinkhorn_iters=10 + ) + vq.train() + with torch.no_grad(): + vq.embedding.weight.copy_( + torch.tensor([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0]]) + ) + x = torch.randn(16, 2) * 0.1 # all clustered at anchor 0 + sinkhorn_ids = vq.quantize(x).ids + vq.use_sinkhorn = False + argmin_ids = vq.quantize(x).ids + self.assertEqual(argmin_ids.unique().numel(), 1) + self.assertGreater(sinkhorn_ids.unique().numel(), 1) + def test_sinkhorn_gumbel_combo_rejected(self) -> None: """Sinkhorn + Gumbel would desync `ids` and `emb`; constructor rejects it.""" with self.assertRaisesRegex(AssertionError, "GUMBEL_SOFTMAX"): diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 102a16c84..fa17321f8 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -63,7 +63,7 @@ message SidRqvae { // Commitment loss type: "l2", "l1" or "cos". optional string commitment_loss = 10 [default = "l2"]; // Commitment loss weights [w1, w2]. Defaults to [1.0, 0.5] when unset - // (applied by RQVAE / ResidualVectorQuantizer). + // (applied by SidRqvae / ResidualVectorQuantizer). repeated float latent_weight = 11; // STE rotation trick. optional bool rotation_trick = 12 [default = false]; From 494267097dcd04cfcfc0159847b970d7cb4729b6 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 16 Jun 2026 03:31:11 +0000 Subject: [PATCH 097/129] =?UTF-8?q?[fix]=20SID:=20address=202nd=20PR=20rev?= =?UTF-8?q?iew=20round=20=E2=80=94=20epsilon=20guard,=20doc=20fixes,=20nar?= =?UTF-8?q?row=20distance=5Ftype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - VectorQuantize: validate sinkhorn_epsilon > 0 at construction (a non-positive epsilon flips exp(-cost*eps) into +Inf -> NaN assignments) + test. - clip_loss docstring: image_embed/text_embed are the decoder reconstructions (B, input_dim), not the quantizer output; fix the logit_scale_* descriptions to match the actual contrasts (self = recon-1 vs recon-2; cl = recon vs same-feature original; logit_scale = recon vs counterpart original). - ResidualVectorQuantizer: narrow distance_type to str (drop the per-layer Union[str, List[str]] branch — unreachable from the scalar proto field). Skipped the init_embed_ per-step-sync item: training-only, sub-ms, usually hidden behind existing loss/metric syncs; the suggested buffer-mirror + _load_from_state_dict adds dual-state complexity not worth the payoff. ruff clean; SID suite green (vector_quantize 12). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/clip_loss.py | 16 ++++++++-------- tzrec/modules/sid/residual_vector_quantizer.py | 14 +++----------- tzrec/modules/sid/vector_quantize.py | 4 ++++ tzrec/modules/sid/vector_quantize_test.py | 7 +++++++ 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py index 5df65b880..493ecf006 100644 --- a/tzrec/loss/clip_loss.py +++ b/tzrec/loss/clip_loss.py @@ -27,14 +27,14 @@ class MaskedCLIPLoss(_Loss): CLIP loss, and recon columns must not serve as negatives. Row/column masks achieve this without data-dependent branching (``torch.compile``-friendly). - Input dict keys: - 'image_embed': (B, D) quantized output of first feature - 'text_embed': (B, D) quantized output of second feature - 'image_embed_ori': (B, D) original embedding of first feature - 'text_embed_ori': (B, D) original embedding of second feature - 'logit_scale_self': scalar self-contrast temperature - 'logit_scale_cl': scalar cross-modal contrast temperature - 'logit_scale': scalar original feature contrast temperature + Input dict keys (all embeddings shape (B, input_dim)): + 'image_embed': reconstructed (decoder) output of feature 1 + 'text_embed': reconstructed (decoder) output of feature 2 + 'image_embed_ori': original embedding of feature 1 + 'text_embed_ori': original embedding of feature 2 + 'logit_scale_self': scalar temperature: recon-1 vs recon-2 + 'logit_scale_cl': scalar temperature: recon vs same-feature original + 'logit_scale': scalar temperature: recon vs counterpart original Output dict keys: 'clip_loss': scalar mean of three contrastive losses (self/ori/cl) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 754fee9bf..8af930fae 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -91,8 +91,7 @@ class ResidualVectorQuantizer(ResidualQuantizer): Default: 'ste'. normalize_residuals (bool): L2-normalize residuals before each quantization layer. Default: False. - distance_type (str|List[str]): distance metric per layer, - 'l2' or 'cosine'. Supports per-layer list. Default: 'l2'. + distance_type (str): distance metric, 'l2' or 'cosine'. Default: 'l2'. commitment_loss (str): commitment loss type, 'l2', 'l1' or 'cos'. Default: 'l2'. latent_weight (List[float]): commitment loss weights [w1, w2]. @@ -120,7 +119,7 @@ def __init__( n_embed: Union[int, List[int]] = 256, forward_mode: str = "ste", normalize_residuals: bool = False, - distance_type: Union[str, List[str]] = "l2", + distance_type: str = "l2", commitment_loss: str = "l2", latent_weight: Sequence[float] = (1.0, 0.5), rotation_trick: bool = False, @@ -164,14 +163,7 @@ def __init__( if is_gumbel and rotation_trick: logger.warning("gumbel_softmax: rotation_trick has no effect; ignoring.") - if isinstance(distance_type, str): - distance_types = [distance_type] * n_layers - else: - assert len(distance_type) == n_layers, ( - "length of distance_type and n_layers must be same, " - f"but got {len(distance_type)} vs {n_layers}" - ) - distance_types = list(distance_type) + distance_types = [distance_type] * n_layers self.layers = nn.ModuleList( [ diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 20c893490..80c25c726 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -144,6 +144,10 @@ def __init__( "`emb` (nearest code), so the returned id and embedding diverge. " "Use STE with Sinkhorn, or Gumbel-Softmax without Sinkhorn." ) + # epsilon sharpens exp(-cost * epsilon); <= 0 flips the kernel and the + # (large, shifted) cost overflows to +Inf -> NaN assignments. + if use_sinkhorn and sinkhorn_epsilon <= 0: + raise ValueError(f"sinkhorn_epsilon must be > 0, got {sinkhorn_epsilon}") # ``n_embed`` / ``embed_dim`` are owned by the QuantizeLayer base. self.forward_mode = forward_mode self.distance_type = distance_type diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index 31cc803b7..21c94d114 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -95,6 +95,13 @@ def test_sinkhorn_gumbel_combo_rejected(self) -> None: use_sinkhorn=True, ) + def test_sinkhorn_epsilon_must_be_positive(self) -> None: + """Reject a non-positive sinkhorn_epsilon (it overflows exp(-cost*eps)).""" + with self.assertRaisesRegex(ValueError, "sinkhorn_epsilon"): + VectorQuantize( + embed_dim=8, n_embed=16, use_sinkhorn=True, sinkhorn_epsilon=0.0 + ) + def test_train_forward_backward_reaches_input(self) -> None: torch.manual_seed(0) vq = VectorQuantize(embed_dim=8, n_embed=16, use_sinkhorn=False) From 22236343a13dc947c0d0ff562805c88cb74966d6 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 16 Jun 2026 03:34:50 +0000 Subject: [PATCH 098/129] [chore] bump version 1.2.19 -> 1.2.20 Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tzrec/version.py b/tzrec/version.py index a5d3e8f5c..b9f275da8 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.19" +__version__ = "1.2.20" From fce258cf9a4cdd727a1f6de2d367d9e496852149 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 18 Jun 2026 03:39:15 +0000 Subject: [PATCH 099/129] [refactor] SidRqvae._predict_rqvae: gate on _is_inference directly `if self.is_train or self.is_eval:` is a roundabout `not self._is_inference` (is_train/is_eval both carry the not-inference term and the self.training term cancels). Use _is_inference directly with an inference early-return, matching _predict_mixed. Behavior-identical: inference emits codes only; train/eval add the recon/loss tensors. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 3491dbd07..77218ba86 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -271,18 +271,19 @@ def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: """Standard RQ-VAE: encode -> quantize -> decode -> loss.""" result = self._forward_rqvae(embedding) - predictions: Dict[str, torch.Tensor] = { + # Inference emits codes only (mirrors _predict_mixed); train/eval also + # carry the recon/loss tensors. + if self._is_inference: + return {"codes": result["codes"]} + + return { "codes": result["codes"], + "quantized": result["quantized"], + "x_hat": result["x_hat"], + "reconstruction_loss": result["reconstruction_loss"], + "quantization_loss": result["quantization_loss"], } - if self.is_train or self.is_eval: - predictions["quantized"] = result["quantized"] - predictions["x_hat"] = result["x_hat"] - predictions["reconstruction_loss"] = result["reconstruction_loss"] - predictions["quantization_loss"] = result["quantization_loss"] - - return predictions - def _predict_mixed( self, embedding: torch.Tensor, batch: Batch ) -> Dict[str, torch.Tensor]: From a752719d6c145e74bc874d1ebf18f668726b7b90 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 18 Jun 2026 03:56:26 +0000 Subject: [PATCH 100/129] [test] SID: fix test-class colocation (foo.py -> foo_test.py) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - residual_quantizer_test.py: delete ResidualKMeansQuantizerTest (a verbatim duplicate-name subset of the colocated residual_kmeans_quantizer_test.py copy, which is a strict superset) — clears the duplicate class name, zero coverage loss. Move ResidualVectorQuantizerTest (RVQ-specific: commitment-loss forward output + FAISS kmeans_init warm-start) to residual_vector_quantizer_test.py. Keep only the base-class tests (NormalizeNEmbed/ResidualQuantizerBase/_Fake Quantizer walk); prune now-dead imports. - sid_rqvae_test.py: move the 3 commitment_loss tests (they construct ResidualVectorQuantizer directly, not SidRqvae) into a CommitmentLossTest in residual_vector_quantizer_test.py. No behavior change; tests relocated to their colocated files. The *_dist_test.py files are left as-is (established multi-process convention). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae_test.py | 69 -------- tzrec/modules/sid/residual_quantizer_test.py | 151 ------------------ .../sid/residual_vector_quantizer_test.py | 112 +++++++++++++ 3 files changed, 112 insertions(+), 220 deletions(-) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 5d62aeaf9..77cd7c151 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -319,36 +319,6 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: self.assertEqual(predictions["reconstruction_loss"].item(), 0.0) self.assertGreater(predictions["clip_loss"].item(), 0.0) - def test_commitment_loss_l1_branch(self) -> None: - """Verify the new commitment_loss='l1' branch runs end-to-end. - - Previously ``"l1"`` silently fell through to the L2 branch. - """ - from tzrec.modules.sid.residual_vector_quantizer import ( - ResidualVectorQuantizer, - ) - - torch.manual_seed(0) - rq = ResidualVectorQuantizer( - embed_dim=8, - n_layers=2, - n_embed=4, - forward_mode="ste", - commitment_loss="l1", - kmeans_init=False, - use_sinkhorn=False, - ) - # Stub the codebook to known centroids so the result is reproducible. - for layer in rq.layers: - torch.nn.init.normal_(layer.embedding.weight, std=0.1) - - x = torch.randn(4, 8, requires_grad=True) - out = rq(x) - # Loss must be a finite scalar with gradient flowing back into x. - self.assertTrue(torch.isfinite(out.quantization_loss)) - out.quantization_loss.backward() - self.assertIsNotNone(x.grad) - def test_sinkhorn_config_enabled_false(self) -> None: """``sinkhorn_config { enabled: false }`` must turn Sinkhorn off. @@ -386,21 +356,6 @@ def test_sinkhorn_config_default_enabled(self) -> None: for layer in model._quantizer.layers: self.assertTrue(layer.use_sinkhorn) - def test_commitment_loss_invalid_raises(self) -> None: - """ResidualVectorQuantizer rejects unknown commitment_loss spellings.""" - from tzrec.modules.sid.residual_vector_quantizer import ( - ResidualVectorQuantizer, - ) - - with self.assertRaisesRegex(AssertionError, "commitment_loss"): - ResidualVectorQuantizer( - embed_dim=8, - n_layers=2, - n_embed=4, - commitment_loss="bogus", - use_sinkhorn=False, - ) - def test_loss_type_l1_and_cosine(self) -> None: """loss_type 'l1' and 'cosine' recon branches run end-to-end. @@ -431,30 +386,6 @@ def test_loss_type_l1_and_cosine(self) -> None: self.assertTrue(torch.isfinite(recon), f"{loss_type} recon not finite") recon.backward() # grad must flow through the decoder - def test_commitment_loss_cos_branch(self) -> None: - """Verify the commitment_loss='cos' branch runs end-to-end.""" - from tzrec.modules.sid.residual_vector_quantizer import ( - ResidualVectorQuantizer, - ) - - torch.manual_seed(0) - rq = ResidualVectorQuantizer( - embed_dim=8, - n_layers=2, - n_embed=4, - forward_mode="ste", - commitment_loss="cos", - kmeans_init=False, - use_sinkhorn=False, - ) - for layer in rq.layers: - torch.nn.init.normal_(layer.embedding.weight, std=0.1) - x = torch.randn(4, 8, requires_grad=True) - out = rq(x) - self.assertTrue(torch.isfinite(out.quantization_loss)) - out.quantization_loss.backward() - self.assertIsNotNone(x.grad) - def test_logit_scale_clamped_prevents_overflow(self) -> None: """A raw logit_scale far above ln(100) must not overflow. diff --git a/tzrec/modules/sid/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py index d37ef7614..c94cc545d 100644 --- a/tzrec/modules/sid/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -14,17 +14,10 @@ import torch from torch import nn -from tzrec.modules.sid.residual_kmeans_quantizer import ( - ResidualKMeansQuantizer, -) from tzrec.modules.sid.residual_quantizer import ( ResidualQuantizer, normalize_n_embed, ) -from tzrec.modules.sid.residual_vector_quantizer import ( - ResidualVectorQuantizer, -) -from tzrec.modules.sid.types import ResidualQuantizerOutput class NormalizeNEmbedTest(unittest.TestCase): @@ -149,149 +142,5 @@ def test_decode_codes_sum_and_dtype(self) -> None: self.assertEqual(recon16.dtype, torch.bfloat16) -class ResidualVectorQuantizerTest(unittest.TestCase): - def setUp(self) -> None: - torch.manual_seed(0) - self.rvq = ResidualVectorQuantizer( - embed_dim=8, n_layers=3, n_embed=16, kmeans_init=False - ) - - def test_is_subclass(self) -> None: - self.assertIsInstance(self.rvq, ResidualQuantizer) - - def test_forward_output(self) -> None: - self.rvq.train() - out = self.rvq(torch.randn(5, 8)) - self.assertIsInstance(out, ResidualQuantizerOutput) - self.assertEqual(out.cluster_ids.shape, (5, 3)) - self.assertEqual(out.quantized_embeddings.shape, (5, 8)) - self.assertTrue(torch.isfinite(out.quantization_loss).all()) - - def test_decode_codes_shared_base(self) -> None: - codes = torch.randint(0, 16, (5, 3)) - recon = self.rvq.decode_codes(codes) - self.assertEqual(recon.shape, (5, 8)) - - def test_get_codes_no_grad(self) -> None: - codes = self.rvq.get_codes(torch.randn(4, 8)) - self.assertEqual(codes.shape, (4, 3)) - - def test_forward_get_codes_consistent_eval(self) -> None: - """get_codes (shared base walk) matches forward's ids in eval.""" - self.rvq.eval() - x = torch.randn(6, 8) - fwd_ids = self.rvq(x).cluster_ids - gc_ids = self.rvq.get_codes(x) - self.assertFalse(gc_ids.requires_grad) - torch.testing.assert_close(gc_ids, fwd_ids) - - def test_faiss_kmeans_init_seeds_codebook(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - rvq = ResidualVectorQuantizer( - embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True - ) - self.assertFalse(bool(rvq.initted.item())) - rvq.train() - # First training forward triggers the FAISS warm-start. - rvq(torch.randn(512, 8)) - self.assertTrue(bool(rvq.initted.item())) - for layer in rvq.layers: - self.assertTrue(torch.isfinite(layer.embedding.weight).all()) - self.assertGreater(layer.embedding.weight.abs().sum().item(), 0.0) - - -class ResidualKMeansQuantizerTest(unittest.TestCase): - def test_is_subclass(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - self.assertIsInstance(rkq, ResidualQuantizer) - - def test_non_uniform_codebook_supported(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=3, n_embed=[8, 4, 16]) - self.assertEqual(rkq.n_embed_list, [8, 4, 16]) - self.assertEqual([layer.centroids.shape[0] for layer in rkq.layers], [8, 4, 16]) - - def test_forward_returns_zeros_before_fit(self) -> None: - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - self.assertFalse(all(layer.is_initialized for layer in rkq.layers)) - codes, quantized = rkq(torch.randn(5, 4)) - self.assertEqual(codes.shape, (5, 2)) - self.assertEqual(quantized.shape, (5, 4)) - - def test_forward_is_fx_traceable(self) -> None: - """Predict forward must FX-trace. - - torchrec's inference pipeline symbolically traces the model, so the - per-batch distance path must be free of data-dependent control flow. - """ - import torch.fx as fx - - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer(embed_dim=4, n_layers=2, n_embed=8) - for layer in rkq.layers: # populate centroids -> is_initialized=True - layer.load_centroids_(torch.randn(8, 4)) - traced = fx.symbolic_trace(rkq) - x = torch.randn(5, 4) - c_eager, q_eager = rkq(x) - c_traced, q_traced = traced(x) - torch.testing.assert_close(c_traced, c_eager) - torch.testing.assert_close(q_traced, q_eager) - - def test_train_offline_non_uniform(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - n_embed = [8, 4, 16] - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=3, n_embed=n_embed, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(512, 4), verbose=False) - self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) - # Each layer fit its own K centroids; codes stay in per-layer range. - codes, _ = rkq(torch.randn(7, 4)) - self.assertEqual(codes.shape, (7, 3)) - for i, k in enumerate(n_embed): - self.assertTrue((codes[:, i] >= 0).all() and (codes[:, i] < k).all()) - - def test_train_offline_then_decode(self) -> None: - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=2, n_embed=8, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(256, 4), verbose=False) - self.assertTrue(all(layer.is_initialized for layer in rkq.layers)) - - codes, _ = rkq(torch.randn(5, 4)) - self.assertTrue((codes >= 0).all() and (codes < 8).all()) - recon = rkq.decode_codes(codes) # inherited from the base - self.assertEqual(recon.shape, (5, 4)) - - def test_forward_get_codes_consistent(self) -> None: - """Forward ids and get_codes both route through the shared walk.""" - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - torch.manual_seed(0) - rkq = ResidualKMeansQuantizer( - embed_dim=4, n_layers=3, n_embed=8, faiss_kmeans_kwargs={"niter": 5} - ) - rkq.train_offline(torch.randn(256, 4), verbose=False) - x = torch.randn(9, 4) - fwd_ids, fwd_quant = rkq(x) - torch.testing.assert_close(rkq.get_codes(x), fwd_ids) - # forward's residual-sum equals the centroid-sum reconstruction. - torch.testing.assert_close(fwd_quant, rkq.decode_codes(fwd_ids)) - - if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index d196f4e2e..9953a2e05 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -13,10 +13,12 @@ import torch +from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, faiss_residual_kmeans, ) +from tzrec.modules.sid.types import ResidualQuantizerOutput class GumbelResidualVQTest(unittest.TestCase): @@ -159,5 +161,115 @@ def test_kmeans_init_too_small_batch_raises(self) -> None: rvq(torch.randn(4, 4)) # 4 < max(codebook)=8 +class ResidualVectorQuantizerTest(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(0) + self.rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=3, n_embed=16, kmeans_init=False + ) + + def test_is_subclass(self) -> None: + self.assertIsInstance(self.rvq, ResidualQuantizer) + + def test_forward_output(self) -> None: + self.rvq.train() + out = self.rvq(torch.randn(5, 8)) + self.assertIsInstance(out, ResidualQuantizerOutput) + self.assertEqual(out.cluster_ids.shape, (5, 3)) + self.assertEqual(out.quantized_embeddings.shape, (5, 8)) + self.assertTrue(torch.isfinite(out.quantization_loss).all()) + + def test_decode_codes_shared_base(self) -> None: + codes = torch.randint(0, 16, (5, 3)) + recon = self.rvq.decode_codes(codes) + self.assertEqual(recon.shape, (5, 8)) + + def test_get_codes_no_grad(self) -> None: + codes = self.rvq.get_codes(torch.randn(4, 8)) + self.assertEqual(codes.shape, (4, 3)) + + def test_forward_get_codes_consistent_eval(self) -> None: + """get_codes (shared base walk) matches forward's ids in eval.""" + self.rvq.eval() + x = torch.randn(6, 8) + fwd_ids = self.rvq(x).cluster_ids + gc_ids = self.rvq.get_codes(x) + self.assertFalse(gc_ids.requires_grad) + torch.testing.assert_close(gc_ids, fwd_ids) + + def test_faiss_kmeans_init_seeds_codebook(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True + ) + self.assertFalse(bool(rvq.initted.item())) + rvq.train() + # First training forward triggers the FAISS warm-start. + rvq(torch.randn(512, 8)) + self.assertTrue(bool(rvq.initted.item())) + for layer in rvq.layers: + self.assertTrue(torch.isfinite(layer.embedding.weight).all()) + self.assertGreater(layer.embedding.weight.abs().sum().item(), 0.0) + + +class CommitmentLossTest(unittest.TestCase): + """ResidualVectorQuantizer commitment-loss branches (l1 / cos / invalid).""" + + def test_commitment_loss_l1_branch(self) -> None: + """The commitment_loss='l1' branch runs end-to-end (no fall-through to l2).""" + torch.manual_seed(0) + rq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=4, + forward_mode="ste", + commitment_loss="l1", + kmeans_init=False, + use_sinkhorn=False, + ) + for layer in rq.layers: + torch.nn.init.normal_(layer.embedding.weight, std=0.1) + x = torch.randn(4, 8, requires_grad=True) + out = rq(x) + self.assertTrue(torch.isfinite(out.quantization_loss)) + out.quantization_loss.backward() + self.assertIsNotNone(x.grad) + + def test_commitment_loss_cos_branch(self) -> None: + """The commitment_loss='cos' branch runs end-to-end.""" + torch.manual_seed(0) + rq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=4, + forward_mode="ste", + commitment_loss="cos", + kmeans_init=False, + use_sinkhorn=False, + ) + for layer in rq.layers: + torch.nn.init.normal_(layer.embedding.weight, std=0.1) + x = torch.randn(4, 8, requires_grad=True) + out = rq(x) + self.assertTrue(torch.isfinite(out.quantization_loss)) + out.quantization_loss.backward() + self.assertIsNotNone(x.grad) + + def test_commitment_loss_invalid_raises(self) -> None: + """ResidualVectorQuantizer rejects unknown commitment_loss spellings.""" + with self.assertRaisesRegex(AssertionError, "commitment_loss"): + ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=4, + commitment_loss="bogus", + use_sinkhorn=False, + ) + + if __name__ == "__main__": unittest.main() From b3f8105e316b4dd173e158ac104eb1719ff496c9 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 18 Jun 2026 08:12:45 +0000 Subject: [PATCH 101/129] [fix] SID: address PR #545 review rounds 2-3 - #3 use config_to_kwargs for sinkhorn kwargs (drop duplicated defaults) - #4 move CLIP loss/param init into init_loss() override - #6 remove SidRqvae train-metric overrides (inherit BaseSidModel no-op) - #7 parameterize near-duplicate sid_rqvae tests - #9 add end-to-end test_sid_rqvae_train_eval + sid_rqvae_mock.config - #11 rename VectorQuantize -> VectorQuantizeLayer - #1 merge *_dist_test.py into colocated *_test.py - C1 inline single-use _train_gumbel; C2 add follow-up TODO for multi-batch warm-start Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/clip_loss_dist_test.py | 128 ------------------ tzrec/loss/clip_loss_test.py | 106 +++++++++++++++ tzrec/models/sid_rqvae.py | 66 +++------ tzrec/models/sid_rqvae_test.py | 92 ++++++------- .../modules/sid/residual_vector_quantizer.py | 17 ++- .../residual_vector_quantizer_dist_test.py | 91 ------------- .../sid/residual_vector_quantizer_test.py | 68 ++++++++++ tzrec/modules/sid/vector_quantize.py | 2 +- tzrec/modules/sid/vector_quantize_test.py | 22 +-- tzrec/tests/configs/sid_rqvae_mock.config | 55 ++++++++ tzrec/tests/sid_integration_test.py | 49 ++++++- 11 files changed, 353 insertions(+), 343 deletions(-) delete mode 100644 tzrec/loss/clip_loss_dist_test.py delete mode 100644 tzrec/modules/sid/residual_vector_quantizer_dist_test.py create mode 100644 tzrec/tests/configs/sid_rqvae_mock.config diff --git a/tzrec/loss/clip_loss_dist_test.py b/tzrec/loss/clip_loss_dist_test.py deleted file mode 100644 index 80d80cf0a..000000000 --- a/tzrec/loss/clip_loss_dist_test.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multi-process tests for the CLIP distributed all-gather path. - -Validates ``_all_gather_with_grad`` (built on the differentiable -``torch.distributed.nn.functional.all_gather``) and ``MaskedCLIPLoss`` -across ranks. Uses NCCL on GPU when >=2 devices are available (the -production path the reviewer cared about), else falls back to gloo/CPU, -so the test is runnable on a multi-GPU box and in CPU CI alike. -""" - -import os -import unittest - -import numpy as np -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from tzrec.loss.clip_loss import MaskedCLIPLoss -from tzrec.utils import misc_util - -WORLD_SIZE = 2 - - -def _init(rank: int, world_size: int, port: int) -> torch.device: - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size - if use_cuda: - torch.cuda.set_device(rank) - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - return torch.device(f"cuda:{rank}") - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) - return torch.device("cpu") - - -def _all_gather_worker(rank: int, world_size: int, port: int) -> None: - device = _init(rank, world_size, port) - # Each rank holds a distinct, rank-identifying tensor. - x = torch.full((2, 3), float(rank + 1), device=device, requires_grad=True) - gathered = MaskedCLIPLoss._all_gather_with_grad([x])[0] - - # Forward: gathered is (world_size*2, 3); rank r contributes rows - # [2r : 2r+2] all equal to (r+1). - assert gathered.shape == (world_size * 2, 3), gathered.shape - for r in range(world_size): - block = gathered[2 * r : 2 * r + 2] - assert torch.allclose(block, torch.full_like(block, float(r + 1))), ( - f"rank{rank}: gathered block {r} wrong: {block}" - ) - - # Backward: identical scalar loss on every rank -> grad to every gathered - # element is 1; the differentiable all_gather sum-reduces across ranks, - # so the local input grad is world_size * ones. - gathered.sum().backward() - assert x.grad is not None, f"rank{rank}: no grad" - assert torch.isfinite(x.grad).all(), f"rank{rank}: non-finite grad" - expected = torch.full_like(x, float(world_size)) - assert torch.allclose(x.grad, expected), f"rank{rank}: grad {x.grad} != {expected}" - dist.destroy_process_group() - - -def _masked_clip_worker(rank: int, world_size: int, port: int) -> None: - device = _init(rank, world_size, port) - torch.manual_seed(1234 + rank) - B, D = 4, 8 - scale = torch.tensor(np.log(1 / 0.07)).exp().to(device) - feats = { - "image_embed": torch.randn(B, D, device=device, requires_grad=True), - "text_embed": torch.randn(B, D, device=device, requires_grad=True), - "image_embed_ori": torch.randn(B, D, device=device), - "text_embed_ori": torch.randn(B, D, device=device), - "logit_scale_self": scale, - "logit_scale_cl": scale, - "logit_scale": scale, - } - mask = torch.ones(B, dtype=torch.bool, device=device) - - loss_fn = MaskedCLIPLoss().to(device) - out = loss_fn(feats, mask) - clip_loss = out["clip_loss"] - assert torch.isfinite(clip_loss).all(), f"rank{rank}: non-finite clip_loss" - assert clip_loss.item() > 0.0, f"rank{rank}: clip_loss not positive" - - clip_loss.backward() - g = feats["image_embed"].grad - assert g is not None and torch.isfinite(g).all(), f"rank{rank}: bad grad" - dist.destroy_process_group() - - -def _run(target) -> None: - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(WORLD_SIZE): - p = ctx.Process(target=target, args=(rank, WORLD_SIZE, port)) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join() - if p.exitcode != 0: - raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") - - -class ClipLossDistTest(unittest.TestCase): - """2-rank tests for the CLIP distributed collectives.""" - - def test_all_gather_with_grad(self) -> None: - _run(_all_gather_worker) - - def test_masked_clip_loss(self) -> None: - _run(_masked_clip_worker) - - -if __name__ == "__main__": - unittest.main() diff --git a/tzrec/loss/clip_loss_test.py b/tzrec/loss/clip_loss_test.py index 8ba9a2c6b..4436f6b33 100644 --- a/tzrec/loss/clip_loss_test.py +++ b/tzrec/loss/clip_loss_test.py @@ -9,12 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy as np import torch +import torch.distributed as dist +import torch.multiprocessing as mp from tzrec.loss.clip_loss import MaskedCLIPLoss +from tzrec.utils import misc_util class AllGatherWithGradTest(unittest.TestCase): @@ -142,5 +146,107 @@ def test_mask_holds_under_large_scale(self) -> None: self.assertTrue(torch.isfinite(feats["image_embed"].grad).all()) +# --- Multi-process tests for the CLIP distributed all-gather path. --- +# Validates ``_all_gather_with_grad`` (built on the differentiable +# ``torch.distributed.nn.functional.all_gather``) and ``MaskedCLIPLoss`` across +# ranks. Uses NCCL on GPU when >=2 devices are available (the production path the +# reviewer cared about), else falls back to gloo/CPU, so it runs on a multi-GPU +# box and in CPU CI alike. + +WORLD_SIZE = 2 + + +def _init(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _all_gather_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + # Each rank holds a distinct, rank-identifying tensor. + x = torch.full((2, 3), float(rank + 1), device=device, requires_grad=True) + gathered = MaskedCLIPLoss._all_gather_with_grad([x])[0] + + # Forward: gathered is (world_size*2, 3); rank r contributes rows + # [2r : 2r+2] all equal to (r+1). + assert gathered.shape == (world_size * 2, 3), gathered.shape + for r in range(world_size): + block = gathered[2 * r : 2 * r + 2] + assert torch.allclose(block, torch.full_like(block, float(r + 1))), ( + f"rank{rank}: gathered block {r} wrong: {block}" + ) + + # Backward: identical scalar loss on every rank -> grad to every gathered + # element is 1; the differentiable all_gather sum-reduces across ranks, + # so the local input grad is world_size * ones. + gathered.sum().backward() + assert x.grad is not None, f"rank{rank}: no grad" + assert torch.isfinite(x.grad).all(), f"rank{rank}: non-finite grad" + expected = torch.full_like(x, float(world_size)) + assert torch.allclose(x.grad, expected), f"rank{rank}: grad {x.grad} != {expected}" + dist.destroy_process_group() + + +def _masked_clip_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + torch.manual_seed(1234 + rank) + B, D = 4, 8 + scale = torch.tensor(np.log(1 / 0.07)).exp().to(device) + feats = { + "image_embed": torch.randn(B, D, device=device, requires_grad=True), + "text_embed": torch.randn(B, D, device=device, requires_grad=True), + "image_embed_ori": torch.randn(B, D, device=device), + "text_embed_ori": torch.randn(B, D, device=device), + "logit_scale_self": scale, + "logit_scale_cl": scale, + "logit_scale": scale, + } + mask = torch.ones(B, dtype=torch.bool, device=device) + + loss_fn = MaskedCLIPLoss().to(device) + out = loss_fn(feats, mask) + clip_loss = out["clip_loss"] + assert torch.isfinite(clip_loss).all(), f"rank{rank}: non-finite clip_loss" + assert clip_loss.item() > 0.0, f"rank{rank}: clip_loss not positive" + + clip_loss.backward() + g = feats["image_embed"].grad + assert g is not None and torch.isfinite(g).all(), f"rank{rank}: bad grad" + dist.destroy_process_group() + + +def _run(target) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=target, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + +class ClipLossDistTest(unittest.TestCase): + """2-rank tests for the CLIP distributed collectives.""" + + def test_all_gather_with_grad(self) -> None: + _run(_all_gather_worker) + + def test_masked_clip_loss(self) -> None: + _run(_masked_clip_worker) + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 77218ba86..bec32f166 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -22,7 +22,6 @@ import numpy as np import torch import torch.nn.functional as F -import torchmetrics from torch import nn from tzrec.datasets.utils import Batch @@ -33,6 +32,7 @@ ResidualVectorQuantizer, ) from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils.config_util import config_to_kwargs from tzrec.utils.logging_util import logger # Cap the CLIP temperatures before ``exp`` (reference CLIP clamps to ln(100)): @@ -104,13 +104,10 @@ def __init__( # Empty -> default (1.0, 0.5); the quantizer validates the arity. latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) - use_sinkhorn = True - sinkhorn_iters = 5 - sinkhorn_epsilon = 10.0 - if cfg.HasField("sinkhorn_config"): - use_sinkhorn = cfg.sinkhorn_config.enabled - sinkhorn_iters = cfg.sinkhorn_config.iters - sinkhorn_epsilon = cfg.sinkhorn_config.epsilon + # Sinkhorn params from the proto: config_to_kwargs flows the proto + # defaults (enabled=True, iters=5, epsilon=10.0) so the model never + # restates them; keys map to the quantizer's use_sinkhorn/iters/epsilon. + sinkhorn_cfg = config_to_kwargs(cfg.sinkhorn_config) self._encoder = self._build_mlp([self._input_dim, *hidden_dims, embed_dim]) # Decoder is the symmetric reverse of the encoder. @@ -129,18 +126,11 @@ def __init__( latent_weight=latent_weight, rotation_trick=cfg.rotation_trick, kmeans_init=cfg.kmeans_init, - use_sinkhorn=use_sinkhorn, - sinkhorn_iters=sinkhorn_iters, - sinkhorn_epsilon=sinkhorn_epsilon, + use_sinkhorn=sinkhorn_cfg["enabled"], + sinkhorn_iters=sinkhorn_cfg["iters"], + sinkhorn_epsilon=sinkhorn_cfg["epsilon"], ) - # CLIP contrastive head (optional). - if self._use_clip: - self._logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._masked_clip_loss_fn = MaskedCLIPLoss() - logger.info( "SidRqvae init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " "n_layers=%d, n_embed=%s, loss_type=%s, use_clip=%s", @@ -311,6 +301,19 @@ def _predict_mixed( } return predictions + def init_loss(self) -> None: + """Initialize loss modules: the optional CLIP contrastive head. + + The three ``logit_scale`` temperatures and the ``MaskedCLIPLoss`` module + are created here (not in ``__init__``) so loss state lives in one place. + """ + super().init_loss() + if self._use_clip: + self._logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._masked_clip_loss_fn = MaskedCLIPLoss() + def loss( self, predictions: Dict[str, torch.Tensor], batch: Batch ) -> Dict[str, torch.Tensor]: @@ -329,30 +332,3 @@ def loss( if self._use_clip: losses["clip_loss"] = predictions["clip_loss"] return losses - - def init_metric(self) -> None: - """Initialize metric modules (shared eval metrics + train-path mse).""" - super().init_metric() - - # Only the train-path reconstruction needs a metric here; unique_sid_ratio - # is eval-only (its torch.unique forces a per-step GPU->host sync). - self._train_metric_modules["mse"] = torchmetrics.MeanSquaredError() - - def update_train_metric( - self, - predictions: Dict[str, torch.Tensor], - batch: Batch, - ) -> None: - """Update train metric state. - - Overrides the BaseSidModel no-op: RQ-VAE has a train-time reconstruction - (the decoder output), so it reports a train-path mse. Eval metrics are - handled by ``BaseSidModel.update_metric`` (SidRqvae emits ``x_hat``). - - Args: - predictions (dict): a dict of predicted result. - batch (Batch): input batch data. - """ - if "x_hat" in predictions: - embedding = self._extract_feature(batch) - self._train_metric_modules["mse"].update(predictions["x_hat"], embedding) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 77cd7c151..7f5221641 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -12,6 +12,7 @@ import unittest import torch +from parameterized import parameterized from torchrec import KeyedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, Batch @@ -319,72 +320,55 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: self.assertEqual(predictions["reconstruction_loss"].item(), 0.0) self.assertGreater(predictions["clip_loss"].item(), 0.0) - def test_sinkhorn_config_enabled_false(self) -> None: - """``sinkhorn_config { enabled: false }`` must turn Sinkhorn off. - - Previously ``use_sinkhorn`` was hard-coded ``True`` and the proto - block was honored only for iters/epsilon. - """ - n_embed_list = [16] * 2 - sid_rqvae_cfg = sid_model_pb2.SidRqvae( + @parameterized.expand( + [ + ("omitted", None, True), # no sinkhorn_config -> on by default + ("enabled_true", True, True), + ("enabled_false", False, False), # was hard-coded True before + ] + ) + def test_sinkhorn_config(self, _name, enabled, expect_use_sinkhorn) -> None: + """``sinkhorn_config.enabled`` (or its omission) drives layer.use_sinkhorn.""" + cfg = sid_model_pb2.SidRqvae( input_dim=32, embed_dim=8, - codebook=n_embed_list, + codebook=[16, 16], forward_mode="ste", loss_type="mse", kmeans_init=False, embedding_feature_name="item_emb", ) - sid_rqvae_cfg.sinkhorn_config.CopyFrom( - sid_model_pb2.SinkhornConfig(enabled=False) - ) - model_config = model_pb2.ModelConfig( - sid_rqvae=sid_rqvae_cfg, + if enabled is not None: + cfg.sinkhorn_config.CopyFrom(sid_model_pb2.SinkhornConfig(enabled=enabled)) + model = SidRqvae( + model_config=model_pb2.ModelConfig(sid_rqvae=cfg), features=[], labels=[] ) - model = SidRqvae(model_config=model_config, features=[], labels=[]) init_parameters(model, device=torch.device("cpu")) - - for layer in model._quantizer.layers: - self.assertFalse(layer.use_sinkhorn) - - def test_sinkhorn_config_default_enabled(self) -> None: - """Omitting ``sinkhorn_config`` preserves on-by-default behavior. - - Back-compat for legacy configs that never set the sub-config. - """ - model = self._create_model() # no sinkhorn_config set for layer in model._quantizer.layers: - self.assertTrue(layer.use_sinkhorn) + self.assertEqual(layer.use_sinkhorn, expect_use_sinkhorn) - def test_loss_type_l1_and_cosine(self) -> None: - """loss_type 'l1' and 'cosine' recon branches run end-to-end. - - Only 'mse' was previously exercised; a typo'd branch would have been - silent. - """ + @parameterized.expand([("mse",), ("l1",), ("cosine",)]) + def test_loss_type_recon_branch(self, loss_type) -> None: + """Each loss_type recon branch runs end-to-end (grad flows).""" B, input_dim = 4, 32 - for loss_type in ("l1", "cosine"): - cfg = sid_model_pb2.SidRqvae( - input_dim=input_dim, - embed_dim=8, - codebook=[16, 16], - forward_mode="ste", - loss_type=loss_type, - kmeans_init=False, - embedding_feature_name="item_emb", - ) - model = SidRqvae( - model_config=model_pb2.ModelConfig(sid_rqvae=cfg), - features=[], - labels=[], - ) - init_parameters(model, device=torch.device("cpu")) - model.train() - model.init_loss() - preds = model.predict(_make_batch(B, input_dim)) - recon = preds["reconstruction_loss"] - self.assertTrue(torch.isfinite(recon), f"{loss_type} recon not finite") - recon.backward() # grad must flow through the decoder + cfg = sid_model_pb2.SidRqvae( + input_dim=input_dim, + embed_dim=8, + codebook=[16, 16], + forward_mode="ste", + loss_type=loss_type, + kmeans_init=False, + embedding_feature_name="item_emb", + ) + model = SidRqvae( + model_config=model_pb2.ModelConfig(sid_rqvae=cfg), features=[], labels=[] + ) + init_parameters(model, device=torch.device("cpu")) + model.train() + model.init_loss() + recon = model.predict(_make_batch(B, input_dim))["reconstruction_loss"] + self.assertTrue(torch.isfinite(recon), f"{loss_type} recon not finite") + recon.backward() # grad must flow through the decoder def test_logit_scale_clamped_prevents_overflow(self) -> None: """A raw logit_scale far above ln(100) must not overflow. diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 8af930fae..29d5cdc17 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -24,7 +24,7 @@ QuantizeForwardMode, ResidualQuantizerOutput, ) -from tzrec.modules.sid.vector_quantize import VectorQuantize +from tzrec.modules.sid.vector_quantize import VectorQuantizeLayer from tzrec.utils.logging_util import logger @@ -167,7 +167,7 @@ def __init__( self.layers = nn.ModuleList( [ - VectorQuantize( + VectorQuantizeLayer( embed_dim=embed_dim, n_embed=self.n_embed_list[i], forward_mode=mode_enum, @@ -233,6 +233,8 @@ def init_embed_(self, data: torch.Tensor) -> None: ) if (not is_ddp) or dist.get_rank() == 0: + # TODO(follow-up): accumulate samples across multiple batches for the + # warm-start fit instead of seeding from only the first training batch. centers = faiss_residual_kmeans( data, self.n_embed_list, @@ -336,18 +338,13 @@ def _apply_rotation_trick( x_unsq - 2 * sum_projection + 2 * rescaled_embeddings ).squeeze(1) - def _train_gumbel(self) -> bool: - """Training pass in Gumbel mode (its soft assignment carries grad).""" - is_gumbel = self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX - return self.training and is_gumbel - def _quantize_layer( self, layer_idx: int, residual: torch.Tensor, temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Quantize one layer's residual via its ``VectorQuantize`` layer. + """Quantize one layer's residual via its ``VectorQuantizeLayer`` layer. STE: raw codebook vector (STE applied on the aggregate in :meth:`forward`). Gumbel: the soft embedding (carries grad directly). @@ -392,7 +389,9 @@ def forward( if self.training: self.init_embed_(input) # first training forward only - train_gumbel = self._train_gumbel() + train_gumbel = ( + self.training and self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX + ) # cumulative[i] = sum after layer i. walk_input = input if train_gumbel else input.detach() diff --git a/tzrec/modules/sid/residual_vector_quantizer_dist_test.py b/tzrec/modules/sid/residual_vector_quantizer_dist_test.py deleted file mode 100644 index 4065e943d..000000000 --- a/tzrec/modules/sid/residual_vector_quantizer_dist_test.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Multi-process test for ResidualVectorQuantizer FAISS kmeans-init. - -Validates the DDP path of ``init_embed_``: the codebook is fit on rank 0 -only and broadcast, so every rank ends with a bit-identical warm start. -(The previous behavior averaged permutation-misaligned per-rank centroids, -which the review flagged as a near-random init.) Uses NCCL on GPU when ->=2 devices are available, else gloo/CPU. -""" - -import os -import unittest - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from tzrec.modules.sid.residual_vector_quantizer import ( - ResidualVectorQuantizer, -) -from tzrec.utils import misc_util - -WORLD_SIZE = 2 - - -def _init(rank: int, world_size: int, port: int) -> torch.device: - os.environ["RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = str(port) - use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size - if use_cuda: - torch.cuda.set_device(rank) - dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) - return torch.device(f"cuda:{rank}") - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) - return torch.device("cpu") - - -def _init_embed_worker(rank: int, world_size: int, port: int) -> None: - device = _init(rank, world_size, port) - # Rank-distinct data: a per-rank average/init would diverge; only a - # broadcast-from-rank0 init yields identical codebooks. - torch.manual_seed(rank) - rvq = ResidualVectorQuantizer( - embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True - ).to(device) - rvq.train() - rvq(torch.randn(512, 8, device=device)) # first forward triggers init_embed_ - assert bool(rvq.initted.item()), f"rank{rank}: not initialized" - - for layer in rvq.layers: - w = layer.embedding.weight.detach().clone() - wmin, wmax = w.clone(), w.clone() - dist.all_reduce(wmin, op=dist.ReduceOp.MIN) - dist.all_reduce(wmax, op=dist.ReduceOp.MAX) - assert torch.allclose(wmin, wmax), ( - f"rank{rank}: codebook differs across ranks (init not broadcast)" - ) - dist.destroy_process_group() - - -class ResidualVectorQuantizerDistTest(unittest.TestCase): - """2-rank test for the FAISS kmeans-init broadcast.""" - - def test_init_embed_broadcast(self) -> None: - port = misc_util.get_free_port() - ctx = mp.get_context("spawn") - procs = [] - for rank in range(WORLD_SIZE): - p = ctx.Process(target=_init_embed_worker, args=(rank, WORLD_SIZE, port)) - p.start() - procs.append(p) - for i, p in enumerate(procs): - p.join() - if p.exitcode != 0: - raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") - - -if __name__ == "__main__": - unittest.main() diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index 9953a2e05..8fec055ca 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import torch +import torch.distributed as dist +import torch.multiprocessing as mp from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.modules.sid.residual_vector_quantizer import ( @@ -19,6 +22,7 @@ faiss_residual_kmeans, ) from tzrec.modules.sid.types import ResidualQuantizerOutput +from tzrec.utils import misc_util class GumbelResidualVQTest(unittest.TestCase): @@ -271,5 +275,69 @@ def test_commitment_loss_invalid_raises(self) -> None: ) +# --- Multi-process test for ResidualVectorQuantizer FAISS kmeans-init. --- +# Validates the DDP path of ``init_embed_``: the codebook is fit on rank 0 only +# and broadcast, so every rank ends with a bit-identical warm start. (The +# previous behavior averaged permutation-misaligned per-rank centroids, which the +# review flagged as a near-random init.) Uses NCCL on GPU when >=2 devices are +# available, else gloo/CPU. + +WORLD_SIZE = 2 + + +def _init(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _init_embed_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + # Rank-distinct data: a per-rank average/init would diverge; only a + # broadcast-from-rank0 init yields identical codebooks. + torch.manual_seed(rank) + rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True + ).to(device) + rvq.train() + rvq(torch.randn(512, 8, device=device)) # first forward triggers init_embed_ + assert bool(rvq.initted.item()), f"rank{rank}: not initialized" + + for layer in rvq.layers: + w = layer.embedding.weight.detach().clone() + wmin, wmax = w.clone(), w.clone() + dist.all_reduce(wmin, op=dist.ReduceOp.MIN) + dist.all_reduce(wmax, op=dist.ReduceOp.MAX) + assert torch.allclose(wmin, wmax), ( + f"rank{rank}: codebook differs across ranks (init not broadcast)" + ) + dist.destroy_process_group() + + +class ResidualVectorQuantizerDistTest(unittest.TestCase): + """2-rank test for the FAISS kmeans-init broadcast.""" + + def test_init_embed_broadcast(self) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=_init_embed_worker, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 80c25c726..093323b1d 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -100,7 +100,7 @@ def _sinkhorn( return Q.t() # (B, K) -class VectorQuantize(QuantizeLayer): +class VectorQuantizeLayer(QuantizeLayer): """Single codebook vector quantization layer (RQ-VAE backend). A gradient-trained ``nn.Embedding`` codebook (the VQ ``QuantizeLayer``), diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index 21c94d114..e37ac8029 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -16,13 +16,13 @@ from tzrec.modules.sid.types import QuantizeForwardMode from tzrec.modules.sid.vector_quantize import ( - VectorQuantize, + VectorQuantizeLayer, _squared_euclidean_distance, ) class SquaredEuclideanDistanceTest(unittest.TestCase): - """Tests for the squared-L2 distance helper used by VectorQuantize.""" + """Tests for the squared-L2 distance helper used by VectorQuantizeLayer.""" def test_squared_euclidean_distance(self) -> None: x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) @@ -34,7 +34,7 @@ def test_squared_euclidean_distance(self) -> None: class VectorQuantizeTest(unittest.TestCase): - """Tests for a single VectorQuantize layer.""" + """Tests for a single VectorQuantizeLayer layer.""" @parameterized.expand( [ @@ -47,7 +47,7 @@ class VectorQuantizeTest(unittest.TestCase): ) def test_train_forward(self, _name, mode, distance_type, use_sinkhorn) -> None: torch.manual_seed(0) - vq = VectorQuantize( + vq = VectorQuantizeLayer( embed_dim=8, n_embed=16, forward_mode=mode, @@ -70,7 +70,7 @@ def test_sinkhorn_balances_assignment(self) -> None: assignment must use more than one code. """ torch.manual_seed(0) - vq = VectorQuantize( + vq = VectorQuantizeLayer( embed_dim=2, n_embed=4, use_sinkhorn=True, sinkhorn_iters=10 ) vq.train() @@ -88,7 +88,7 @@ def test_sinkhorn_balances_assignment(self) -> None: def test_sinkhorn_gumbel_combo_rejected(self) -> None: """Sinkhorn + Gumbel would desync `ids` and `emb`; constructor rejects it.""" with self.assertRaisesRegex(AssertionError, "GUMBEL_SOFTMAX"): - VectorQuantize( + VectorQuantizeLayer( embed_dim=8, n_embed=16, forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, @@ -98,13 +98,13 @@ def test_sinkhorn_gumbel_combo_rejected(self) -> None: def test_sinkhorn_epsilon_must_be_positive(self) -> None: """Reject a non-positive sinkhorn_epsilon (it overflows exp(-cost*eps)).""" with self.assertRaisesRegex(ValueError, "sinkhorn_epsilon"): - VectorQuantize( + VectorQuantizeLayer( embed_dim=8, n_embed=16, use_sinkhorn=True, sinkhorn_epsilon=0.0 ) def test_train_forward_backward_reaches_input(self) -> None: torch.manual_seed(0) - vq = VectorQuantize(embed_dim=8, n_embed=16, use_sinkhorn=False) + vq = VectorQuantizeLayer(embed_dim=8, n_embed=16, use_sinkhorn=False) vq.train() x = torch.randn(5, 8, requires_grad=True) out = vq.quantize(x) @@ -115,7 +115,7 @@ def test_train_forward_backward_reaches_input(self) -> None: def test_eval_forward_is_plain_lookup(self) -> None: torch.manual_seed(0) - vq = VectorQuantize(embed_dim=4, n_embed=8) + vq = VectorQuantizeLayer(embed_dim=4, n_embed=8) vq.eval() x = torch.randn(3, 4) out = vq.quantize(x) @@ -128,7 +128,7 @@ def test_gumbel_train_ids_match_embedding(self) -> None: # (Under the old code ids came from argmin and could disagree with the # gumbel-sampled embedding.) torch.manual_seed(0) - vq = VectorQuantize( + vq = VectorQuantizeLayer( embed_dim=8, n_embed=16, forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, @@ -141,7 +141,7 @@ def test_gumbel_train_ids_match_embedding(self) -> None: def test_gumbel_train_distances_are_differentiable(self) -> None: # Gumbel needs the assignment differentiable: grad must reach the input. torch.manual_seed(0) - vq = VectorQuantize( + vq = VectorQuantizeLayer( embed_dim=8, n_embed=16, forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, diff --git a/tzrec/tests/configs/sid_rqvae_mock.config b/tzrec/tests/configs/sid_rqvae_mock.config new file mode 100644 index 000000000..1446ff872 --- /dev/null +++ b/tzrec/tests/configs/sid_rqvae_mock.config @@ -0,0 +1,55 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/sid_rqvae_mock" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 2 + save_checkpoints_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 256 + dataset_type: ParquetDataset + fg_mode: FG_DAG + num_workers: 2 +} +feature_configs { + raw_feature { + feature_name: "item_emb" + expression: "item:embedding" + value_dim: 16 + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "item_emb" + group_type: DEEP + } + sid_rqvae { + input_dim: 16 + embed_dim: 8 + hidden_dims: 16 + codebook: 16 + codebook: 16 + codebook: 16 + forward_mode: "ste" + loss_type: "mse" + kmeans_init: false + embedding_feature_name: "item_emb" + } +} diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 53f24a1d3..8b09ebd88 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -44,7 +44,12 @@ def tearDown(self): if self.success and os.path.exists(self.test_dir): shutil.rmtree(self.test_dir) - def _prepare_config(self, num_rows: int, dim: int) -> str: + def _prepare_config( + self, + num_rows: int, + dim: int, + base_config: str = "tzrec/tests/configs/sid_rqkmeans_mock.config", + ) -> str: """Write an embedding parquet + a SID config pointed at it. Single dense ``embedding`` column, no labels — SID reads the item @@ -61,9 +66,7 @@ def _prepare_config(self, num_rows: int, dim: int) -> str: # train_input_path set -> load_config_for_test uses it as-is (the # FG_DAG auto-mock path is match-model-specific; SID is single-table). - config = config_util.load_pipeline_config( - "tzrec/tests/configs/sid_rqkmeans_mock.config" - ) + config = config_util.load_pipeline_config(base_config) config.train_input_path = data_glob config.eval_input_path = data_glob config_path = os.path.join(self.test_dir, "sid.config") @@ -117,6 +120,44 @@ def test_sid_rqkmeans_train_eval(self): self.assertLess(metrics["rel_loss"], 1.0) self.assertGreater(metrics["unique_sid_ratio"], 0.0) + @unittest.skipIf( + torch.cuda.is_available(), + "the SID integration tests run on the CPU CI job; forcing CPU on a " + "CUDA-built (GPU) image is unreliable.", + ) + def test_sid_rqvae_train_eval(self): + """End-to-end SidRqvae train -> checkpoint -> eval (gradient-trained). + + Exercises the full RQ-VAE pipeline (encode -> quantize -> decode, + commitment + reconstruction loss, gradient training, checkpoint, eval + metrics). On random data the model need not beat the all-zero baseline, + so only finiteness + nonzero SID variety are asserted (not rel_loss<1). + """ + config_path = self._prepare_config( + num_rows=2048, + dim=16, + base_config="tzrec/tests/configs/sid_rqvae_mock.config", + ) + self.success = utils.test_train_eval(config_path, self.test_dir) + self.assertTrue(self.success) + # save_checkpoints_epochs=1 persists a checkpoint during training. + self.assertTrue( + glob.glob(os.path.join(self.test_dir, "train", "model.ckpt-*")), + "no checkpoint persisted", + ) + result_path = os.path.join(self.test_dir, "train", "eval_result.txt") + self.assertTrue(os.path.exists(result_path), "no eval_result.txt produced") + with open(result_path) as f: + lines = [ln for ln in f.read().splitlines() if ln.strip()] + self.assertTrue(lines, "eval_result.txt is empty") + metrics = json.loads(lines[-1]) + for key in ("mse", "rel_loss", "unique_sid_ratio"): + self.assertIn(key, metrics) + self.assertTrue( + math.isfinite(metrics[key]), f"{key} not finite: {metrics[key]}" + ) + self.assertGreater(metrics["unique_sid_ratio"], 0.0) + if __name__ == "__main__": unittest.main() From a8c4592b294157a630da971b0c122d34b6742194 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 18 Jun 2026 09:42:15 +0000 Subject: [PATCH 102/129] [fix] SID: review D4/D5 (quantize API) + fix rqvae integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit D4: replace the per-call quantize() `temperature` arg with a `gumbel_temperature` init param across the quantizer chain (QuantizeLayer ABC, VQ/KMeans layers, residual quantizers, SidRqvae). D5: derive the STE wrap in VectorQuantizeLayer.quantize from `forward_mode == STE` instead of an `apply_ste` arg — the RVQ residual walk runs on a detached input, so the per-layer wrap is a numeric no-op and the aggregate STE carries the gradient. CI fix: test_sid_rqvae_train_eval asserted eval_result.txt, but train_eval writes train_eval_result_v2.txt; add the standalone test_eval pass (mirroring the rqkmeans test) so eval_result.txt is produced. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 11 +++---- tzrec/modules/sid/kmeans_quantize.py | 4 +-- tzrec/modules/sid/quantize_layer.py | 2 +- tzrec/modules/sid/quantize_layer_test.py | 2 +- .../modules/sid/residual_kmeans_quantizer.py | 4 +-- tzrec/modules/sid/residual_quantizer.py | 8 ++--- tzrec/modules/sid/residual_quantizer_test.py | 2 +- .../modules/sid/residual_vector_quantizer.py | 21 ++++++------ tzrec/modules/sid/vector_quantize.py | 32 +++++++++---------- tzrec/tests/sid_integration_test.py | 6 ++++ 10 files changed, 42 insertions(+), 50 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index bec32f166..be30975af 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -180,12 +180,10 @@ def _recon_loss( mask = mask.float() return (per_sample * mask).sum() / mask.sum().clamp(min=1) - def _forward_rqvae( - self, x: torch.Tensor, temperature: float = 1.0 - ) -> Dict[str, torch.Tensor]: + def _forward_rqvae(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """Standard RQ-VAE forward: encode -> quantize -> decode -> loss.""" z_e = self._encode(x) - quant = self._quantizer(z_e, temperature=temperature) + quant = self._quantizer(z_e) x_hat = self._decode(quant.quantized_embeddings) recon_loss = self._recon_loss(x_hat, x) @@ -204,15 +202,14 @@ def _forward_mixed( fea1: torch.Tensor, fea2: torch.Tensor, clip_mask: torch.Tensor, - temperature: float = 1.0, ) -> Dict[str, torch.Tensor]: """Mixed recon + CLIP forward (all rows dual-pathed; mask splits loss).""" z_e1 = self._encode(fea1) - quant1 = self._quantizer(z_e1, temperature=temperature) + quant1 = self._quantizer(z_e1) x_hat1 = self._decode(quant1.quantized_embeddings) z_e2 = self._encode(fea2) - quant2 = self._quantizer(z_e2, temperature=temperature) + quant2 = self._quantizer(z_e2) x_hat2 = self._decode(quant2.quantized_embeddings) recon_mask = ~clip_mask diff --git a/tzrec/modules/sid/kmeans_quantize.py b/tzrec/modules/sid/kmeans_quantize.py index 2641e6963..b17b11e22 100644 --- a/tzrec/modules/sid/kmeans_quantize.py +++ b/tzrec/modules/sid/kmeans_quantize.py @@ -254,7 +254,7 @@ def _load_from_state_dict( ) @torch.no_grad() - def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + def quantize(self, x: torch.Tensor) -> QuantizeOutput: """Assign points to the nearest centroid and gather them. Uses ``torch.cdist`` (L2); argmin is invariant to the monotonic sqrt, @@ -262,11 +262,9 @@ def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: (measure zero for real embeddings), where either centroid is valid. Before the FAISS fit (uninitialized) this returns all-zero codes + embeddings so the residual walk stays a no-op and the model is callable. - ``temperature`` is unused (no soft assignment). Args: x (Tensor): data points, shape (B, D). - temperature (float): unused. Returns: QuantizeOutput: ``ids`` (B,) and ``embeddings`` (B, D). diff --git a/tzrec/modules/sid/quantize_layer.py b/tzrec/modules/sid/quantize_layer.py index e7f344fda..5c712589a 100644 --- a/tzrec/modules/sid/quantize_layer.py +++ b/tzrec/modules/sid/quantize_layer.py @@ -39,7 +39,7 @@ def __init__(self, n_embed: int, embed_dim: int) -> None: self.embed_dim = embed_dim @abstractmethod - def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + def quantize(self, x: torch.Tensor) -> QuantizeOutput: """Assign ``x`` (B, D) to the codebook, returning codes + embeddings.""" raise NotImplementedError diff --git a/tzrec/modules/sid/quantize_layer_test.py b/tzrec/modules/sid/quantize_layer_test.py index 28eb4849b..4be9648a5 100644 --- a/tzrec/modules/sid/quantize_layer_test.py +++ b/tzrec/modules/sid/quantize_layer_test.py @@ -31,7 +31,7 @@ def __init__(self, n_embed: int, embed_dim: int) -> None: n_embed, embed_dim ) - def quantize(self, x: torch.Tensor, temperature: float = 1.0) -> QuantizeOutput: + def quantize(self, x: torch.Tensor) -> QuantizeOutput: dist = torch.cdist(x, self._codebook) ids = dist.argmin(dim=-1) return QuantizeOutput(embeddings=self.lookup(ids), ids=ids) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index bc7e64300..b9433a3c6 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -82,7 +82,6 @@ def _quantize_layer( self, layer_idx: int, residual: torch.Tensor, - temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Nearest-centroid assignment for one layer (delegates to the layer). @@ -92,13 +91,12 @@ def _quantize_layer( Args: layer_idx (int): quantization layer index. residual (Tensor): current residual, shape (B, D). - temperature (float): unused (no soft assignment). Returns: codes (Tensor): cluster indices, shape (B,). quantized (Tensor): selected centroids, shape (B, D). """ - out = self.layers[layer_idx].quantize(residual, temperature) + out = self.layers[layer_idx].quantize(residual) return out.ids, out.embeddings def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/tzrec/modules/sid/residual_quantizer.py b/tzrec/modules/sid/residual_quantizer.py index 6b80f2e33..2c0c704c7 100644 --- a/tzrec/modules/sid/residual_quantizer.py +++ b/tzrec/modules/sid/residual_quantizer.py @@ -98,17 +98,15 @@ def _quantize_layer( self, layer_idx: int, residual: torch.Tensor, - temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Assign one layer's codes and look up its quantized vector. Backend primitive behind the residual walk (encode-direction mirror of - :meth:`_lookup_code`). ``temperature`` is used only by the VQ backend. + :meth:`_lookup_code`). Args: layer_idx (int): quantization layer index. residual (Tensor): current residual, shape (B, D). - temperature (float): Gumbel-Softmax temperature (VQ only). Returns: codes (Tensor): per-layer cluster ids, shape (B,). @@ -119,7 +117,6 @@ def _quantize_layer( def _residual_pass( self, input: torch.Tensor, - temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """Shared residual walk: per-layer assign, subtract, accumulate. @@ -129,7 +126,6 @@ def _residual_pass( Args: input (Tensor): input embeddings, shape (B, D). - temperature (float): forwarded to :meth:`_quantize_layer`. Returns: cluster_ids (Tensor): stacked codes, shape (B, n_layers). @@ -144,7 +140,7 @@ def _residual_pass( for i in range(self.n_layers): if self.normalize_residuals: residual = F.normalize(residual, dim=-1) - codes, quantized = self._quantize_layer(i, residual, temperature) + codes, quantized = self._quantize_layer(i, residual) all_codes.append(codes) aggregated = aggregated + quantized cumulative.append(aggregated) diff --git a/tzrec/modules/sid/residual_quantizer_test.py b/tzrec/modules/sid/residual_quantizer_test.py index c94cc545d..0f8c3bcf9 100644 --- a/tzrec/modules/sid/residual_quantizer_test.py +++ b/tzrec/modules/sid/residual_quantizer_test.py @@ -72,7 +72,7 @@ def __init__(self, embed_dim, n_layers, n_embed=5, normalize_residuals=False): ] ) - def _quantize_layer(self, layer_idx, residual, temperature=1.0): + def _quantize_layer(self, layer_idx, residual): codes = (residual.detach() @ self.books[layer_idx].t()).argmax(dim=-1) return codes, self.books[layer_idx][codes] diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 29d5cdc17..0902fc86f 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -105,6 +105,7 @@ class ResidualVectorQuantizer(ResidualQuantizer): use_sinkhorn (bool): Sinkhorn uniform assignment. Default: True. sinkhorn_iters (int): Sinkhorn iterations. Default: 5. sinkhorn_epsilon (float): Sinkhorn sharpness. Default: 10.0. + gumbel_temperature (float): Gumbel-Softmax temperature. Default: 1.0. """ _FORWARD_MODE_MAP = { @@ -127,6 +128,7 @@ def __init__( use_sinkhorn: bool = True, sinkhorn_iters: int = 5, sinkhorn_epsilon: float = 10.0, + gumbel_temperature: float = 1.0, ) -> None: super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) assert commitment_loss in ("l2", "l1", "cos"), ( @@ -175,6 +177,7 @@ def __init__( use_sinkhorn=use_sinkhorn, sinkhorn_iters=sinkhorn_iters, sinkhorn_epsilon=sinkhorn_epsilon, + gumbel_temperature=gumbel_temperature, ) for i in range(n_layers) ] @@ -342,7 +345,6 @@ def _quantize_layer( self, layer_idx: int, residual: torch.Tensor, - temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize one layer's residual via its ``VectorQuantizeLayer`` layer. @@ -352,24 +354,22 @@ def _quantize_layer( Args: layer_idx (int): quantization layer index. residual (Tensor): current residual, shape (B, D). - temperature (float): Gumbel-Softmax temperature. Returns: ids (Tensor): per-layer cluster ids, shape (B,). emb (Tensor): the raw codebook vector (STE/eval) or the soft embedding (Gumbel), with grad, shape (B, D). """ - # apply_ste=False: Gumbel ignores it (returns the soft embedding that - # carries grad); STE/eval get the raw codebook vector in one gather (STE - # is applied on the aggregate in :meth:`forward`), avoiding the discarded - # per-layer straight-through wrap + a second codebook gather. - out = self.layers[layer_idx].quantize(residual, temperature, apply_ste=False) + # On the STE residual walk the residual is detached, so the layer's + # straight-through wrap is a numeric no-op; the real STE gradient comes + # from the aggregate STE in :meth:`forward`. Gumbel returns the soft + # embedding that carries grad directly. + out = self.layers[layer_idx].quantize(residual) return out.ids, out.embeddings def forward( self, input: torch.Tensor, - temperature: float = 1.0, ) -> ResidualQuantizerOutput: """Forward the multi-layer residual quantization. @@ -380,7 +380,6 @@ def forward( Args: input (Tensor): input embeddings, shape (B, D). - temperature (float): temperature for Gumbel-Softmax. Returns: ResidualQuantizerOutput: (cluster_ids, quantized_embeddings, @@ -395,9 +394,7 @@ def forward( # cumulative[i] = sum after layer i. walk_input = input if train_gumbel else input.detach() - cluster_ids, aggregated_quants, cumulative = self._residual_pass( - walk_input, temperature - ) + cluster_ids, aggregated_quants, cumulative = self._residual_pass(walk_input) commitment_loss = torch.mean( torch.stack([self._single_commitment_loss(input, c) for c in cumulative]) diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 093323b1d..4fdb56387 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -121,6 +121,8 @@ class VectorQuantizeLayer(QuantizeLayer): sinkhorn_iters (int): number of Sinkhorn iterations. Default: 5. sinkhorn_epsilon (float): Sinkhorn sharpness parameter for exp(-cost * epsilon). Default: 10.0. + gumbel_temperature (float): Gumbel-Softmax temperature (tau), used only + in GUMBEL_SOFTMAX training. Default: 1.0. """ def __init__( @@ -132,6 +134,7 @@ def __init__( use_sinkhorn: bool = True, sinkhorn_iters: int = 5, sinkhorn_epsilon: float = 10.0, + gumbel_temperature: float = 1.0, ) -> None: super().__init__(n_embed=n_embed, embed_dim=embed_dim) # Sinkhorn drives `ids` (balanced assignment), Gumbel drives `emb` @@ -154,6 +157,7 @@ def __init__( self.use_sinkhorn = use_sinkhorn self.sinkhorn_iters = sinkhorn_iters self.sinkhorn_epsilon = sinkhorn_epsilon + self.gumbel_temperature = gumbel_temperature self.embedding = nn.Embedding(n_embed, embed_dim) nn.init.kaiming_uniform_(self.embedding.weight) @@ -220,23 +224,15 @@ def _find_nearest_embedding(self, x: torch.Tensor) -> torch.Tensor: return ids - def quantize( - self, x: torch.Tensor, temperature: float = 1.0, apply_ste: bool = True - ) -> QuantizeOutput: + def quantize(self, x: torch.Tensor) -> QuantizeOutput: """Assign ``x`` to the codebook (the :class:`QuantizeLayer` interface). - Commitment loss is computed by the caller - (:meth:`ResidualVectorQuantizer._single_commitment_loss`); device follows - ``x``, so this runs on CPU or GPU unchanged. + Commitment loss is computed by the caller; device follows ``x``, so this + runs on CPU or GPU unchanged. The Gumbel temperature is the + ``gumbel_temperature`` init parameter. Args: x (Tensor): input vectors, shape (B, D). - temperature (float): temperature for Gumbel-Softmax. - apply_ste (bool): in STE training, wrap the embedding with the - straight-through estimator (``x + (q - x).detach()``). Set False - when the caller re-applies STE on the aggregate - (:class:`ResidualVectorQuantizer`): the raw codebook vector is - returned and the otherwise-discarded wrap + re-gather is avoided. Returns: QuantizeOutput: named tuple of (embeddings, ids). @@ -245,17 +241,21 @@ def quantize( # both emb and ids, so the saved code matches the vector used. if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: logits = -self._compute_distances(x) # (B, n_embed), differentiable - weights = F.gumbel_softmax(logits, tau=temperature, hard=True, dim=-1) + weights = F.gumbel_softmax( + logits, tau=self.gumbel_temperature, hard=True, dim=-1 + ) emb = weights @ self.embedding.weight ids = weights.argmax(dim=-1) return QuantizeOutput(embeddings=emb, ids=ids) # STE / eval: nearest-neighbour assignment under no_grad, one codebook - # gather. STE wrap only when the caller wants it (standalone use); the - # RVQ caller passes apply_ste=False and re-applies STE on the aggregate. + # gather. In STE training, wrap with the straight-through estimator so + # grad reaches the encoder. (Under the RVQ residual walk the input is + # detached, so this per-layer wrap is a numeric no-op and the aggregate + # STE in ResidualVectorQuantizer.forward carries the gradient.) ids = self._find_nearest_embedding(x) quantized = self.embedding(ids) - if self.training and apply_ste: + if self.training and self.forward_mode == QuantizeForwardMode.STE: quantized = x + (quantized - x).detach() # grad passes to x return QuantizeOutput(embeddings=quantized, ids=ids) diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 8b09ebd88..27a042085 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -139,6 +139,12 @@ def test_sid_rqvae_train_eval(self): base_config="tzrec/tests/configs/sid_rqvae_mock.config", ) self.success = utils.test_train_eval(config_path, self.test_dir) + # train_eval writes train_eval_result_v2.txt; a standalone eval pass + # (like the rqkmeans test) writes eval_result.txt from the checkpoint. + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) self.assertTrue(self.success) # save_checkpoints_epochs=1 persists a checkpoint during training. self.assertTrue( From dbab20b1f6e34e1ab27d8c5ebdff8a0b5b654a7d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 18 Jun 2026 10:02:31 +0000 Subject: [PATCH 103/129] [refactor] SID: config-driven losses via LossConfig sid_loss (review D1-D3) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit D1: add a `oneof sid_loss` to LossConfig (recon_l2/l1/cosine, commitment, sid_clip) and centralize SID loss init/compute in BaseSidModel, mirroring RankModel — `init_loss` registers modules from ModelConfig.losses and `loss()` computes each term from `predictions`. SidRqvae's loss_type / commitment_loss / latent_weight / clip_config proto fields are removed in favor of `losses`. D2: extract the commitment loss into tzrec/loss/commitment_loss.py (CommitmentLoss); the quantizer no longer computes it and instead exposes the per-layer cumulative quantized vectors as ResidualQuantizerOutput.latents. D3: SidRqvae.predict() now returns only the raw tensors the losses consume (x_hat/recon_target/encoder_out/latents + CLIP embeds); all loss math moved to BaseSidModel.loss(). Tests: rewrite sid_rqvae_test + residual_vector_quantizer_test to the new contract, add commitment_loss_test; update sid_rqvae_mock.config. All SID tests green. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/commitment_loss.py | 77 +++++ tzrec/loss/commitment_loss_test.py | 65 ++++ tzrec/models/sid_model.py | 139 ++++++++- tzrec/models/sid_rqvae.py | 241 ++++----------- tzrec/models/sid_rqvae_test.py | 285 +++++++----------- .../modules/sid/residual_vector_quantizer.py | 77 +---- .../sid/residual_vector_quantizer_test.py | 59 +--- tzrec/modules/sid/types.py | 10 +- tzrec/protos/loss.proto | 38 +++ tzrec/protos/models/sid_model.proto | 25 +- tzrec/tests/configs/sid_rqvae_mock.config | 11 +- 11 files changed, 525 insertions(+), 502 deletions(-) create mode 100644 tzrec/loss/commitment_loss.py create mode 100644 tzrec/loss/commitment_loss_test.py diff --git a/tzrec/loss/commitment_loss.py b/tzrec/loss/commitment_loss.py new file mode 100644 index 000000000..5fdb7434a --- /dev/null +++ b/tzrec/loss/commitment_loss.py @@ -0,0 +1,77 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CommitmentLoss: VQ-VAE commitment loss for residual quantizers.""" + +from typing import Sequence + +import torch +import torch.nn.functional as F +from torch import nn + + +class CommitmentLoss(nn.Module): + """Commitment loss between the encoder output and the quantized vectors. + + Operates on a residual quantizer's per-layer cumulative quantized vectors + (the ``latents`` of :class:`~tzrec.modules.sid.types.ResidualQuantizerOutput`). + Both VQ-VAE directions are summed and averaged over the residual layers: + + - ``loss1`` = encoder-toward-quant (gradient flows into the encoder), w1 + - ``loss2`` = quant-toward-encoder (gradient flows into the codebook), w2 + + Args: + latent_weight (Sequence[float]): commitment weights ``[w1, w2]``. + Default: ``(1.0, 0.5)``. + commitment_type (str): distance, ``"l2"``, ``"l1"`` or ``"cos"``. + Default: ``"l2"``. + """ + + def __init__( + self, + latent_weight: Sequence[float] = (1.0, 0.5), + commitment_type: str = "l2", + ) -> None: + super().__init__() + if len(latent_weight) != 2: + raise ValueError( + f"latent_weight must have exactly 2 values [w1, w2], got " + f"{list(latent_weight)}" + ) + assert commitment_type in ("l2", "l1", "cos"), ( + f"commitment_type must be 'l2', 'l1' or 'cos', got {commitment_type!r}" + ) + self.commitment_w1, self.commitment_w2 = latent_weight + self.commitment_type = commitment_type + + def forward(self, encoder_out: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """Compute the commitment loss. + + Args: + encoder_out (Tensor): encoder output (the quantizer input), shape + (B, D). + latents (Tensor): per-layer cumulative quantized vectors, shape + (B, n_layers, D). + + Returns: + Tensor: scalar commitment loss (averaged over layers). + """ + x = encoder_out.unsqueeze(1) # (B, 1, D) -> broadcasts over layers + if self.commitment_type == "cos": + loss1 = (1 - F.cosine_similarity(x, latents.detach(), dim=-1)).mean() + loss2 = (1 - F.cosine_similarity(x.detach(), latents, dim=-1)).mean() + elif self.commitment_type == "l1": + loss1 = (x - latents.detach()).abs().mean() + loss2 = (x.detach() - latents).abs().mean() + else: # "l2" + loss1 = (x - latents.detach()).pow(2.0).mean() + loss2 = (x.detach() - latents).pow(2.0).mean() + return self.commitment_w1 * loss1 + self.commitment_w2 * loss2 diff --git a/tzrec/loss/commitment_loss_test.py b/tzrec/loss/commitment_loss_test.py new file mode 100644 index 000000000..92a0bc09c --- /dev/null +++ b/tzrec/loss/commitment_loss_test.py @@ -0,0 +1,65 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from tzrec.loss.commitment_loss import CommitmentLoss + + +class CommitmentLossTest(unittest.TestCase): + """Tests for the standalone CommitmentLoss module.""" + + @parameterized.expand([("l2",), ("l1",), ("cos",)]) + def test_branch_runs_and_backprops(self, commitment_type) -> None: + """Each commitment_type runs end-to-end; grad reaches both operands.""" + torch.manual_seed(0) + loss_fn = CommitmentLoss( + latent_weight=(1.0, 0.5), commitment_type=commitment_type + ) + B, L, D = 4, 3, 8 + encoder_out = torch.randn(B, D, requires_grad=True) + latents = torch.randn(B, L, D, requires_grad=True) + out = loss_fn(encoder_out, latents) + self.assertEqual(out.shape, ()) + self.assertTrue(torch.isfinite(out)) + out.backward() + # loss1 (encoder-toward-quant) feeds encoder_out; loss2 feeds latents. + self.assertIsNotNone(encoder_out.grad) + self.assertIsNotNone(latents.grad) + self.assertTrue(torch.isfinite(encoder_out.grad).all()) + + def test_latent_weight_wrong_length_raises(self) -> None: + """latent_weight must be exactly [w1, w2].""" + for bad in ([1.0], [1.0, 0.5, 0.25]): + with self.assertRaisesRegex(ValueError, "latent_weight"): + CommitmentLoss(latent_weight=bad) + + def test_invalid_commitment_type_raises(self) -> None: + """An unknown commitment_type is rejected.""" + with self.assertRaisesRegex(AssertionError, "commitment_type"): + CommitmentLoss(commitment_type="bogus") + + def test_weights_scale_the_two_directions(self) -> None: + """w1/w2 weight the encoder-toward-quant / quant-toward-encoder terms.""" + torch.manual_seed(0) + encoder_out = torch.randn(4, 8) + latents = torch.randn(4, 3, 8) + base = CommitmentLoss(latent_weight=(1.0, 0.5), commitment_type="l2") + zero = CommitmentLoss(latent_weight=(0.0, 0.0), commitment_type="l2") + self.assertGreater(base(encoder_out, latents).item(), 0.0) + self.assertEqual(zero(encoder_out, latents).item(), 0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 8db468799..475994bf6 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -13,16 +13,33 @@ from typing import Any, Dict, List, Optional +import numpy as np import torch +import torch.nn.functional as F import torchmetrics +from torch import nn from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature +from tzrec.loss.clip_loss import MaskedCLIPLoss +from tzrec.loss.commitment_loss import CommitmentLoss from tzrec.metrics.relative_l1 import RelativeL1 from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel +from tzrec.protos.loss_pb2 import LossConfig from tzrec.protos.model_pb2 import ModelConfig +# Cap the CLIP temperatures before ``exp`` (reference CLIP clamps to ln(100)): +# an unbounded ``logit_scale`` overflows to +Inf -> NaN grad -> corrupt param. +_LOGIT_SCALE_MAX = float(np.log(100)) + +# sid_loss reconstruction variants -> the reduction used by ``_recon_loss``. +_RECON_TYPES = { + "recon_l2_loss": "mse", + "recon_l1_loss": "l1", + "recon_cosine_loss": "cosine", +} + class BaseSidModel(BaseModel): """Shared base for semantic-ID (SID) generation models. @@ -99,12 +116,126 @@ def _extract_feature( return kt[feature_name] def init_loss(self) -> None: - """Initialize loss modules. + """Initialize SID loss modules from ``ModelConfig.losses``. + + Each ``LossConfig`` sets one ``sid_loss`` oneof variant (a reconstruction + loss, the commitment loss, or the CLIP loss). Mirrors ``RankModel``: the + config drives which loss modules are registered, and :meth:`loss` + computes them from ``predictions``. + """ + for loss_cfg in self._base_model_config.losses: + self._init_sid_loss_impl(loss_cfg) + + def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: + """Register the module (if any) for one ``sid_loss`` config.""" + loss_type = loss_cfg.WhichOneof("sid_loss") + if loss_type in _RECON_TYPES: + return # reconstruction losses are functional (no module) + elif loss_type == "commitment_loss": + cfg = loss_cfg.commitment_loss + latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) + self._loss_modules["commitment_loss"] = CommitmentLoss( + latent_weight=latent_weight, + commitment_type=cfg.commitment_type, + ) + elif loss_type == "sid_clip_loss": + # The three learnable CLIP temperatures + the masked-CLIP module. + self._logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._loss_modules["sid_clip_loss"] = MaskedCLIPLoss() + else: + raise ValueError( + f"LossConfig for a SID model must set a sid_loss variant, " + f"got {loss_type!r}" + ) + + def loss( + self, predictions: Dict[str, torch.Tensor], batch: Batch + ) -> Dict[str, torch.Tensor]: + """Compute the configured SID losses from ``predictions``. + + Args: + predictions (dict): a dict of predicted result (the raw tensors the + losses consume — ``x_hat``/``recon_target`` for reconstruction, + ``encoder_out``/``latents`` for commitment, and the CLIP embeds). + batch (Batch): input batch data. + + Return: + losses (dict): a dict of loss tensor keyed by the sid_loss variant. + """ + losses: Dict[str, torch.Tensor] = {} + for loss_cfg in self._base_model_config.losses: + losses.update(self._sid_loss_impl(predictions, loss_cfg)) + return losses + + def _sid_loss_impl( + self, predictions: Dict[str, torch.Tensor], loss_cfg: LossConfig + ) -> Dict[str, torch.Tensor]: + """Compute one ``sid_loss`` term from ``predictions``.""" + loss_type = loss_cfg.WhichOneof("sid_loss") + if loss_type in _RECON_TYPES: + loss = self._recon_loss( + predictions["x_hat"], + predictions["recon_target"], + _RECON_TYPES[loss_type], + predictions.get("recon_mask"), + ) + return {loss_type: loss} + elif loss_type == "commitment_loss": + loss = self._loss_modules["commitment_loss"]( + predictions["encoder_out"], predictions["latents"] + ) + return {"commitment_loss": loss} + elif loss_type == "sid_clip_loss": + feats = { + "image_embed": predictions["clip_image"], + "text_embed": predictions["clip_text"], + "image_embed_ori": predictions["clip_image_ori"], + "text_embed_ori": predictions["clip_text_ori"], + "logit_scale_self": self._logit_scale_self.clamp( + max=_LOGIT_SCALE_MAX + ).exp(), + "logit_scale_cl": self._logit_scale_cl.clamp( + max=_LOGIT_SCALE_MAX + ).exp(), + "logit_scale": self._logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp(), + } + out = self._loss_modules["sid_clip_loss"](feats, predictions["clip_mask"]) + return {"sid_clip_loss": out["clip_loss"]} + else: + raise ValueError(f"unsupported sid_loss variant: {loss_type!r}") - SID models compute their losses internally and pass them through - ``predictions``; there is no external loss module to register. + def _recon_loss( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + recon_type: str, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Reconstruction loss for ``recon_type`` ('mse'|'l1'|'cosine'). + + Returns the mean over all rows, or — when ``mask`` (a per-row bool) is + given — the mean over only the masked-in rows (the mixed recon+CLIP path + applies recon loss to recon rows only). No data-dependent branching, so + it stays ``torch.compile``-friendly. + + Args: + x_hat (Tensor): reconstructed output, shape (B, D). + x (Tensor): original input, shape (B, D). + recon_type (str): 'mse', 'l1' or 'cosine'. + mask (Tensor, optional): per-row bool; rows to include. """ - pass + if recon_type == "mse": + per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) + elif recon_type == "l1": + per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) + else: # 'cosine' + per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) + if mask is None: + return per_sample.mean() + mask = mask.float() + return (per_sample * mask).sum() / mask.sum().clamp(min=1) def init_metric(self) -> None: """Initialize the eval metrics shared by all SID models. diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index be30975af..4c9069a1f 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -11,22 +11,21 @@ """SidRqvae: SID generation model using RQ-VAE (Encoder + VQ + Decoder). -End-to-end differentiable training with reconstruction loss and commitment -loss. Optionally supports CLIP contrastive learning. The encoder/decoder, -residual vector quantizer, and CLIP head all live directly on the model — -there is no intermediate ``RQVAE`` module wrapper. +End-to-end differentiable training. The reconstruction, commitment and optional +CLIP contrastive losses are configured via ``ModelConfig.losses`` (the +``LossConfig`` ``sid_loss`` oneof) and computed centrally in +:meth:`BaseSidModel.loss`; :meth:`predict` only produces the raw tensors those +losses consume. The encoder/decoder and residual vector quantizer live directly +on the model — there is no intermediate ``RQVAE`` module wrapper. """ from typing import Any, Dict, List, Optional -import numpy as np import torch -import torch.nn.functional as F from torch import nn from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature -from tzrec.loss.clip_loss import MaskedCLIPLoss from tzrec.models.sid_model import BaseSidModel from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, @@ -35,10 +34,6 @@ from tzrec.utils.config_util import config_to_kwargs from tzrec.utils.logging_util import logger -# Cap the CLIP temperatures before ``exp`` (reference CLIP clamps to ln(100)): -# an unbounded ``logit_scale`` overflows to +Inf -> NaN grad -> corrupt param. -_LOGIT_SCALE_MAX = float(np.log(100)) - class SidRqvae(BaseSidModel): """SID generation model using RQ-VAE (Encoder + VQ + Decoder). @@ -48,8 +43,9 @@ class SidRqvae(BaseSidModel): Decoder: embed_dim -> ... -> hidden_dims[0] -> input_dim (ReLU between hidden layers; the decoder mirrors the encoder.) - When ``clip_config`` is set, ``predict`` runs a dual path and a masked - CLIP contrastive loss is added for the CLIP-pair rows. + Losses are config-driven (``ModelConfig.losses`` / ``sid_loss`` oneof). When a + ``sid_clip_loss`` is configured, ``predict`` runs a dual (image/text) path and + the masked CLIP contrastive loss is applied to the CLIP-pair rows. Args: model_config (ModelConfig): an instance of ModelConfig. @@ -79,17 +75,18 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) cfg = self._model_config # SidRqvae proto message - self._loss_type = cfg.loss_type - assert self._loss_type in ("mse", "l1", "cosine"), ( - f"loss_type must be 'mse', 'l1' or 'cosine', got '{self._loss_type}'" - ) - self._use_clip = cfg.HasField("clip_config") - self._clip_feature_name = ( - cfg.clip_config.clip_feature_name if self._use_clip else None - ) - self._is_clip_pair_feature_name = ( - cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None - ) + + # CLIP is enabled by a `sid_clip_loss` entry in ModelConfig.losses, which + # also carries the paired-feature names (data wiring). + self._clip_feature_name: Optional[str] = None + self._is_clip_pair_feature_name: Optional[str] = None + for loss_cfg in self._base_model_config.losses: + if loss_cfg.WhichOneof("sid_loss") == "sid_clip_loss": + self._clip_feature_name = loss_cfg.sid_clip_loss.clip_feature_name + self._is_clip_pair_feature_name = ( + loss_cfg.sid_clip_loss.is_clip_pair_feature_name + ) + self._use_clip = self._clip_feature_name is not None embed_dim = cfg.embed_dim # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero @@ -101,8 +98,6 @@ def __init__( ) if any(h < 1 for h in hidden_dims): raise ValueError(f"every hidden_dims entry must be >= 1, got {hidden_dims}") - # Empty -> default (1.0, 0.5); the quantizer validates the arity. - latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) # Sinkhorn params from the proto: config_to_kwargs flows the proto # defaults (enabled=True, iters=5, epsilon=10.0) so the model never @@ -122,8 +117,6 @@ def __init__( forward_mode=cfg.forward_mode, normalize_residuals=self._normalize_residuals, distance_type=cfg.distance_type, - commitment_loss=cfg.commitment_loss, - latent_weight=latent_weight, rotation_trick=cfg.rotation_trick, kmeans_init=cfg.kmeans_init, use_sinkhorn=sinkhorn_cfg["enabled"], @@ -133,13 +126,12 @@ def __init__( logger.info( "SidRqvae init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " - "n_layers=%d, n_embed=%s, loss_type=%s, use_clip=%s", + "n_layers=%d, n_embed=%s, use_clip=%s", self._input_dim, embed_dim, hidden_dims, self._n_layers, self._n_embed_list, - self._loss_type, self._use_clip, ) @@ -151,96 +143,12 @@ def _decode(self, z_q: torch.Tensor) -> torch.Tensor: """Decode. (B, embed_dim) -> (B, input_dim).""" return self._decoder(z_q) - def _recon_loss( - self, - x_hat: torch.Tensor, - x: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Reconstruction loss for the configured ``loss_type``. - - Returns the mean over all rows, or — when ``mask`` (a per-row bool) - is given — the mean over only the masked-in rows (the mixed - recon+CLIP path applies recon loss to recon rows only). No - data-dependent branching, so it stays ``torch.compile``-friendly. - - Args: - x_hat (Tensor): reconstructed output, shape (B, D). - x (Tensor): original input, shape (B, D). - mask (Tensor, optional): per-row bool; rows to include. - """ - if self._loss_type == "mse": - per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) - elif self._loss_type == "l1": - per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - else: # 'cosine' - per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) - if mask is None: - return per_sample.mean() - mask = mask.float() - return (per_sample * mask).sum() / mask.sum().clamp(min=1) - - def _forward_rqvae(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: - """Standard RQ-VAE forward: encode -> quantize -> decode -> loss.""" - z_e = self._encode(x) - quant = self._quantizer(z_e) - x_hat = self._decode(quant.quantized_embeddings) - - recon_loss = self._recon_loss(x_hat, x) - quant_loss = quant.quantization_loss - return { - "x_hat": x_hat, - "codes": quant.cluster_ids, - "quantized": quant.quantized_embeddings, - "reconstruction_loss": recon_loss, - "quantization_loss": quant_loss, - "loss": recon_loss + quant_loss, - } - - def _forward_mixed( - self, - fea1: torch.Tensor, - fea2: torch.Tensor, - clip_mask: torch.Tensor, - ) -> Dict[str, torch.Tensor]: - """Mixed recon + CLIP forward (all rows dual-pathed; mask splits loss).""" - z_e1 = self._encode(fea1) - quant1 = self._quantizer(z_e1) - x_hat1 = self._decode(quant1.quantized_embeddings) - - z_e2 = self._encode(fea2) - quant2 = self._quantizer(z_e2) - x_hat2 = self._decode(quant2.quantized_embeddings) - - recon_mask = ~clip_mask - recon_loss = self._recon_loss(x_hat1, fea1, recon_mask) - - features = { - "image_embed": x_hat1, - "text_embed": x_hat2, - "image_embed_ori": fea1, - "text_embed_ori": fea2, - "logit_scale_self": self._logit_scale_self.clamp( - max=_LOGIT_SCALE_MAX - ).exp(), - "logit_scale_cl": self._logit_scale_cl.clamp(max=_LOGIT_SCALE_MAX).exp(), - "logit_scale": self._logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp(), - } - clip_result = self._masked_clip_loss_fn(features, clip_mask) - - commitment = (quant1.quantization_loss + quant2.quantization_loss) / 2 - return { - "codes": quant1.cluster_ids, - "quantized": quant1.quantized_embeddings, - "x_hat": x_hat1, - "reconstruction_loss": recon_loss, - "clip_loss": clip_result["clip_loss"], - "quantization_loss": commitment, - } - def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. + Returns the raw tensors the configured losses consume (computed in + :meth:`BaseSidModel.loss`); inference emits codes only. + Args: batch (Batch): input batch data. @@ -248,84 +156,59 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: predictions (dict): a dict of predicted result. """ embedding = self._extract_feature(batch) - if self._use_clip: return self._predict_mixed(embedding, batch) - else: - return self._predict_rqvae(embedding) + return self._predict_rqvae(embedding) def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: - """Standard RQ-VAE: encode -> quantize -> decode -> loss.""" - result = self._forward_rqvae(embedding) - - # Inference emits codes only (mirrors _predict_mixed); train/eval also - # carry the recon/loss tensors. + """Standard RQ-VAE: encode -> quantize -> decode.""" + z_e = self._encode(embedding) + quant = self._quantizer(z_e) if self._is_inference: - return {"codes": result["codes"]} - + return {"codes": quant.cluster_ids} return { - "codes": result["codes"], - "quantized": result["quantized"], - "x_hat": result["x_hat"], - "reconstruction_loss": result["reconstruction_loss"], - "quantization_loss": result["quantization_loss"], + "codes": quant.cluster_ids, + "x_hat": self._decode(quant.quantized_embeddings), + "recon_target": embedding, + "encoder_out": z_e, + "latents": quant.latents, } def _predict_mixed( self, embedding: torch.Tensor, batch: Batch ) -> Dict[str, torch.Tensor]: - """Mixed recon + CLIP: extract fea2 and clip_mask, run the dual path.""" - # Inference skips the dual path: fea2 / clip_mask aren't needed - # when we only emit codes. + """Mixed recon + CLIP: dual path over the embedding + its paired feature. + + ``encoder_out`` / ``latents`` stack both paths so the commitment loss + averages over them; ``recon_mask`` (= non-CLIP rows) restricts the recon + loss to reconstruction-only rows. + """ if self._is_inference: - result = self._forward_rqvae(embedding) - return {"codes": result["codes"]} + z_e = self._encode(embedding) + return {"codes": self._quantizer(z_e).cluster_ids} fea2 = self._extract_feature(batch, self._clip_feature_name) - is_clip_pair_raw = self._extract_feature(batch, self._is_clip_pair_feature_name) clip_mask = is_clip_pair_raw.view(is_clip_pair_raw.shape[0], -1)[:, 0] > 0.5 - result = self._forward_mixed(embedding, fea2, clip_mask) - - predictions: Dict[str, torch.Tensor] = { - "codes": result["codes"], - "quantized": result["quantized"], - "x_hat": result["x_hat"], - "reconstruction_loss": result["reconstruction_loss"], - "clip_loss": result["clip_loss"], - "quantization_loss": result["quantization_loss"], - } - return predictions - - def init_loss(self) -> None: - """Initialize loss modules: the optional CLIP contrastive head. - - The three ``logit_scale`` temperatures and the ``MaskedCLIPLoss`` module - are created here (not in ``__init__``) so loss state lives in one place. - """ - super().init_loss() - if self._use_clip: - self._logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._masked_clip_loss_fn = MaskedCLIPLoss() - - def loss( - self, predictions: Dict[str, torch.Tensor], batch: Batch - ) -> Dict[str, torch.Tensor]: - """Compute loss of the model. + z_e1 = self._encode(embedding) + quant1 = self._quantizer(z_e1) + x_hat1 = self._decode(quant1.quantized_embeddings) - Args: - predictions (dict): a dict of predicted result. - batch (Batch): input batch data. + z_e2 = self._encode(fea2) + quant2 = self._quantizer(z_e2) + x_hat2 = self._decode(quant2.quantized_embeddings) - Return: - losses (dict): a dict of loss tensor. - """ - losses: Dict[str, torch.Tensor] = {} - losses["reconstruction_loss"] = predictions["reconstruction_loss"] - losses["quantization_loss"] = predictions["quantization_loss"] - if self._use_clip: - losses["clip_loss"] = predictions["clip_loss"] - return losses + return { + "codes": quant1.cluster_ids, + "x_hat": x_hat1, + "recon_target": embedding, + "recon_mask": ~clip_mask, + "encoder_out": torch.cat([z_e1, z_e2], dim=0), + "latents": torch.cat([quant1.latents, quant2.latents], dim=0), + "clip_image": x_hat1, + "clip_text": x_hat2, + "clip_image_ori": embedding, + "clip_text_ori": fea2, + "clip_mask": clip_mask, + } diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 7f5221641..ae53241ef 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -17,7 +17,7 @@ from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.models.sid_rqvae import SidRqvae -from tzrec.protos import model_pb2 +from tzrec.protos import loss_pb2, model_pb2 from tzrec.protos.models import sid_model_pb2 from tzrec.utils.state_dict_util import init_parameters @@ -43,39 +43,78 @@ def _make_batch( ) +def _recon_loss_cfg(kind: str = "recon_l2_loss") -> loss_pb2.LossConfig: + """A LossConfig whose sid_loss oneof is the given recon variant.""" + lc = loss_pb2.LossConfig() + getattr(lc, kind).SetInParent() + return lc + + +def _commitment_cfg( + latent_weight=(1.0, 0.5), commitment_type="l2" +) -> loss_pb2.LossConfig: + lc = loss_pb2.LossConfig() + lc.commitment_loss.latent_weight.extend(latent_weight) + lc.commitment_loss.commitment_type = commitment_type + return lc + + +def _clip_cfg() -> loss_pb2.LossConfig: + lc = loss_pb2.LossConfig() + lc.sid_clip_loss.clip_feature_name = "image_emb" + lc.sid_clip_loss.is_clip_pair_feature_name = "is_clip_pair" + return lc + + class SidRqvaeTest(unittest.TestCase): """Tests for SidRqvae model.""" - def _create_model(self, use_clip=False, input_dim=32, embed_dim=8, n_layers=2): - """Helper to create a SidRqvae model with minimal config.""" + def _create_model( + self, + use_clip=False, + input_dim=32, + embed_dim=8, + n_layers=2, + recon="recon_l2_loss", + ): + """Helper to create a SidRqvae model with config-driven losses.""" n_embed_list = [16] * n_layers sid_rqvae_cfg = sid_model_pb2.SidRqvae( input_dim=input_dim, embed_dim=embed_dim, codebook=n_embed_list, forward_mode="ste", - loss_type="mse", kmeans_init=False, embedding_feature_name="item_emb", ) + losses = [_recon_loss_cfg(recon), _commitment_cfg()] if use_clip: - sid_rqvae_cfg.clip_config.CopyFrom( - sid_model_pb2.ClipConfig( - clip_feature_name="image_emb", - is_clip_pair_feature_name="is_clip_pair", - ) - ) + losses.append(_clip_cfg()) # SID models read the item-embedding dense feature directly from the # batch; they do not consume feature_groups, so none is set (which # keeps the config consistent with the empty ``features`` list). - model_config = model_pb2.ModelConfig( - sid_rqvae=sid_rqvae_cfg, - ) + model_config = model_pb2.ModelConfig(sid_rqvae=sid_rqvae_cfg, losses=losses) model = SidRqvae(model_config=model_config, features=[], labels=[]) init_parameters(model, device=torch.device("cpu")) return model + def _clip_batch(self, B, input_dim, is_clip_pair): + return Batch( + dense_features={ + BASE_DATA_GROUP: KeyedTensor.from_tensor_list( + keys=["item_emb", "image_emb", "is_clip_pair"], + tensors=[ + torch.randn(B, input_dim), + torch.randn(B, input_dim), + is_clip_pair, + ], + ) + }, + sparse_features={}, + labels={}, + ) + def test_rqvae_train_mode(self) -> None: """Test SidRqvae in train mode: predict -> loss -> metric.""" B, input_dim = 4, 32 @@ -87,44 +126,39 @@ def test_rqvae_train_mode(self) -> None: batch = _make_batch(B, input_dim) predictions = model.predict(batch) - # Train mode should return all fields + # predict() returns only the raw tensors the losses consume. self.assertIn("codes", predictions) - self.assertIn("quantized", predictions) self.assertIn("x_hat", predictions) - self.assertIn("reconstruction_loss", predictions) - self.assertIn("quantization_loss", predictions) + self.assertIn("encoder_out", predictions) + self.assertIn("latents", predictions) self.assertEqual(predictions["codes"].shape[0], B) - # Loss should return reconstruction_loss + quantization_loss + # loss() computes the configured recon + commitment terms. losses = model.loss(predictions, batch) - self.assertIn("reconstruction_loss", losses) - self.assertIn("quantization_loss", losses) + self.assertIn("recon_l2_loss", losses) + self.assertIn("commitment_loss", losses) - # Total loss should be a scalar and have grad total_loss = sum(losses.values()) self.assertTrue(total_loss.requires_grad) - # Metric update should not raise model.update_metric(predictions, batch, losses) metrics = model.compute_metric() self.assertIn("mse", metrics) self.assertIn("unique_sid_ratio", metrics) def test_rqvae_eval_mode(self) -> None: - """Test SidRqvae in eval mode: predict returns all fields.""" + """Test SidRqvae in eval mode: predict returns the recon fields.""" B, input_dim = 4, 32 model = self._create_model(input_dim=input_dim) model.eval() - batch = _make_batch(B, input_dim) - predictions = model.predict(batch) + predictions = model.predict(_make_batch(B, input_dim)) - # Eval mode (not inference) should return all fields + # Eval mode (not inference) exposes x_hat for the metric + losses. self.assertIn("codes", predictions) - self.assertIn("quantized", predictions) self.assertIn("x_hat", predictions) - self.assertIn("reconstruction_loss", predictions) - self.assertIn("quantization_loss", predictions) + self.assertIn("encoder_out", predictions) + self.assertIn("latents", predictions) def test_rqvae_inference_mode(self) -> None: """Test SidRqvae in inference mode: only codes returned.""" @@ -133,13 +167,10 @@ def test_rqvae_inference_mode(self) -> None: model.eval() model.set_is_inference(True) - batch = _make_batch(B, input_dim) - predictions = model.predict(batch) - - # Inference mode should only return codes + predictions = model.predict(_make_batch(B, input_dim)) self.assertIn("codes", predictions) self.assertNotIn("x_hat", predictions) - self.assertNotIn("reconstruction_loss", predictions) + self.assertNotIn("latents", predictions) def test_rqvae_clip_mode(self) -> None: """Test SidRqvae with CLIP mixed mode (mixed recon + clip batch).""" @@ -148,45 +179,23 @@ def test_rqvae_clip_mode(self) -> None: model.train() model.init_loss() - # Build mixed batch: first half recon, second half clip. - # With the explicit is_clip_pair column the actual tensor values - # no longer matter — the flag column drives routing. - item_emb = torch.randn(B, input_dim) - image_emb = torch.randn(B, input_dim) is_clip_pair = torch.zeros(B, 1) - is_clip_pair[B // 2 :] = 1.0 # clip rows - - batch = Batch( - dense_features={ - BASE_DATA_GROUP: KeyedTensor.from_tensor_list( - keys=["item_emb", "image_emb", "is_clip_pair"], - tensors=[item_emb, image_emb, is_clip_pair], - ) - }, - sparse_features={}, - labels={}, - ) + is_clip_pair[B // 2 :] = 1.0 # second half clip + batch = self._clip_batch(B, input_dim, is_clip_pair) predictions = model.predict(batch) - - # Mixed mode returns reconstruction_loss, clip_loss, quantization_loss self.assertIn("codes", predictions) - self.assertIn("reconstruction_loss", predictions) - self.assertIn("clip_loss", predictions) - self.assertIn("quantization_loss", predictions) self.assertIn("x_hat", predictions) + self.assertIn("clip_image", predictions) self.assertEqual(predictions["codes"].shape[0], B) - # Loss should return all three losses = model.loss(predictions, batch) - self.assertIn("reconstruction_loss", losses) - self.assertIn("clip_loss", losses) - self.assertIn("quantization_loss", losses) + self.assertIn("recon_l2_loss", losses) + self.assertIn("commitment_loss", losses) + self.assertIn("sid_clip_loss", losses) total_loss = sum(losses.values()) self.assertTrue(total_loss.requires_grad) - - # Backward should work total_loss.backward() has_grad = any( p.grad is not None and p.grad.abs().sum() > 0 for p in model.parameters() @@ -194,66 +203,28 @@ def test_rqvae_clip_mode(self) -> None: self.assertTrue(has_grad) def test_rqvae_clip_all_recon(self) -> None: - """Test mixed mode with all-recon batch (edge case).""" + """Mixed mode with all-recon batch: clip term 0, recon term > 0.""" B, input_dim = 4, 32 model = self._create_model(input_dim=input_dim, use_clip=True) model.train() model.init_loss() - # All recon: is_clip_pair = 0 everywhere - item_emb = torch.randn(B, input_dim) - image_emb = torch.randn(B, input_dim) - is_clip_pair = torch.zeros(B, 1) - - batch = Batch( - dense_features={ - BASE_DATA_GROUP: KeyedTensor.from_tensor_list( - keys=["item_emb", "image_emb", "is_clip_pair"], - tensors=[item_emb, image_emb, is_clip_pair], - ) - }, - sparse_features={}, - labels={}, - ) - - predictions = model.predict(batch) - model.loss(predictions, batch) - - # clip_loss should be 0 (no clip rows) - self.assertEqual(predictions["clip_loss"].item(), 0.0) - # reconstruction_loss should be > 0 - self.assertGreater(predictions["reconstruction_loss"].item(), 0.0) + batch = self._clip_batch(B, input_dim, torch.zeros(B, 1)) + losses = model.loss(model.predict(batch), batch) + self.assertEqual(losses["sid_clip_loss"].item(), 0.0) + self.assertGreater(losses["recon_l2_loss"].item(), 0.0) def test_rqvae_clip_all_clip(self) -> None: - """Test mixed mode with all-clip batch (edge case).""" + """Mixed mode with all-clip batch: recon term 0, clip term > 0.""" B, input_dim = 4, 32 model = self._create_model(input_dim=input_dim, use_clip=True) model.train() model.init_loss() - # All clip: is_clip_pair = 1 everywhere - item_emb = torch.randn(B, input_dim) - image_emb = torch.randn(B, input_dim) - is_clip_pair = torch.ones(B, 1) - - batch = Batch( - dense_features={ - BASE_DATA_GROUP: KeyedTensor.from_tensor_list( - keys=["item_emb", "image_emb", "is_clip_pair"], - tensors=[item_emb, image_emb, is_clip_pair], - ) - }, - sparse_features={}, - labels={}, - ) - - predictions = model.predict(batch) - model.loss(predictions, batch) - - # reconstruction_loss should be 0 (no recon rows) - self.assertEqual(predictions["reconstruction_loss"].item(), 0.0) - # clip_loss should be > 0 - self.assertGreater(predictions["clip_loss"].item(), 0.0) + batch = self._clip_batch(B, input_dim, torch.ones(B, 1)) + losses = model.loss(model.predict(batch), batch) + self.assertEqual(losses["recon_l2_loss"].item(), 0.0) + self.assertGreater(losses["sid_clip_loss"].item(), 0.0) def test_rqvae_backward(self) -> None: """Test that backward pass works without errors.""" @@ -263,37 +234,33 @@ def test_rqvae_backward(self) -> None: model.init_loss() batch = _make_batch(B, input_dim) - predictions = model.predict(batch) - losses = model.loss(predictions, batch) - total_loss = sum(losses.values()) - total_loss.backward() + losses = model.loss(model.predict(batch), batch) + sum(losses.values()).backward() - # Encoder params should have gradients has_grad = any( p.grad is not None and p.grad.abs().sum() > 0 for p in model.parameters() ) self.assertTrue(has_grad) - def test_latent_weight_wrong_length_raises(self) -> None: - """latent_weight must be exactly [w1, w2]; a bad length fails fast.""" + def test_commitment_latent_weight_wrong_length_raises(self) -> None: + """A commitment_loss with a bad latent_weight length fails in init_loss.""" for bad in ([1.0], [1.0, 0.5, 0.25]): cfg = sid_model_pb2.SidRqvae( - input_dim=32, - embed_dim=8, - codebook=[16, 16], - kmeans_init=False, - latent_weight=bad, + input_dim=32, embed_dim=8, codebook=[16, 16], kmeans_init=False ) - model_config = model_pb2.ModelConfig(sid_rqvae=cfg) + model_config = model_pb2.ModelConfig( + sid_rqvae=cfg, losses=[_commitment_cfg(latent_weight=bad)] + ) + model = SidRqvae(model_config=model_config, features=[], labels=[]) with self.assertRaisesRegex(ValueError, "latent_weight"): - SidRqvae(model_config=model_config, features=[], labels=[]) + model.init_loss() def test_clip_mask_uses_flag_not_equality(self) -> None: """The is_clip_pair flag, not bit-exact equality, drives routing. Build a batch where ``image_emb == item_emb`` numerically but - ``is_clip_pair=1``: row must route to the CLIP branch (under the - old bit-exact logic it would have been silently relabeled recon). + ``is_clip_pair=1``: rows must route to the CLIP branch (under the old + bit-exact logic they would have been silently relabeled recon). """ B, input_dim = 4, 32 model = self._create_model(input_dim=input_dim, use_clip=True) @@ -301,24 +268,19 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: model.init_loss() item_emb = torch.randn(B, input_dim) - image_emb = item_emb.clone() # bit-identical - is_clip_pair = torch.ones(B, 1) # but flagged as clip - batch = Batch( dense_features={ BASE_DATA_GROUP: KeyedTensor.from_tensor_list( keys=["item_emb", "image_emb", "is_clip_pair"], - tensors=[item_emb, image_emb, is_clip_pair], + tensors=[item_emb, item_emb.clone(), torch.ones(B, 1)], ) }, sparse_features={}, labels={}, ) - - predictions = model.predict(batch) - # All rows flagged as clip -> reconstruction_loss should be 0, clip_loss > 0 - self.assertEqual(predictions["reconstruction_loss"].item(), 0.0) - self.assertGreater(predictions["clip_loss"].item(), 0.0) + losses = model.loss(model.predict(batch), batch) + self.assertEqual(losses["recon_l2_loss"].item(), 0.0) + self.assertGreater(losses["sid_clip_loss"].item(), 0.0) @parameterized.expand( [ @@ -334,7 +296,6 @@ def test_sinkhorn_config(self, _name, enabled, expect_use_sinkhorn) -> None: embed_dim=8, codebook=[16, 16], forward_mode="ste", - loss_type="mse", kmeans_init=False, embedding_feature_name="item_emb", ) @@ -347,28 +308,25 @@ def test_sinkhorn_config(self, _name, enabled, expect_use_sinkhorn) -> None: for layer in model._quantizer.layers: self.assertEqual(layer.use_sinkhorn, expect_use_sinkhorn) - @parameterized.expand([("mse",), ("l1",), ("cosine",)]) - def test_loss_type_recon_branch(self, loss_type) -> None: - """Each loss_type recon branch runs end-to-end (grad flows).""" + @parameterized.expand( + [ + ("recon_l2_loss",), + ("recon_l1_loss",), + ("recon_cosine_loss",), + ] + ) + def test_recon_loss_variant_branch(self, recon) -> None: + """Each recon variant runs end-to-end (grad flows through the decoder).""" B, input_dim = 4, 32 - cfg = sid_model_pb2.SidRqvae( - input_dim=input_dim, - embed_dim=8, - codebook=[16, 16], - forward_mode="ste", - loss_type=loss_type, - kmeans_init=False, - embedding_feature_name="item_emb", - ) - model = SidRqvae( - model_config=model_pb2.ModelConfig(sid_rqvae=cfg), features=[], labels=[] - ) - init_parameters(model, device=torch.device("cpu")) + model = self._create_model(input_dim=input_dim, recon=recon) model.train() model.init_loss() - recon = model.predict(_make_batch(B, input_dim))["reconstruction_loss"] - self.assertTrue(torch.isfinite(recon), f"{loss_type} recon not finite") - recon.backward() # grad must flow through the decoder + losses = model.loss( + model.predict(_make_batch(B, input_dim)), _make_batch(B, input_dim) + ) + recon_loss = losses[recon] + self.assertTrue(torch.isfinite(recon_loss), f"{recon} not finite") + recon_loss.backward() # grad must flow through the decoder def test_logit_scale_clamped_prevents_overflow(self) -> None: """A raw logit_scale far above ln(100) must not overflow. @@ -386,22 +344,9 @@ def test_logit_scale_clamped_prevents_overflow(self) -> None: model._logit_scale_cl.fill_(100.0) model._logit_scale.fill_(100.0) - batch = Batch( - dense_features={ - BASE_DATA_GROUP: KeyedTensor.from_tensor_list( - keys=["item_emb", "image_emb", "is_clip_pair"], - tensors=[ - torch.randn(B, input_dim), - torch.randn(B, input_dim), - torch.ones(B, 1), - ], - ) - }, - sparse_features={}, - labels={}, - ) + batch = self._clip_batch(B, input_dim, torch.ones(B, 1)) losses = model.loss(model.predict(batch), batch) - self.assertTrue(torch.isfinite(losses["clip_loss"])) + self.assertTrue(torch.isfinite(losses["sid_clip_loss"])) sum(losses.values()).backward() for p in ( model._logit_scale_self, diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 0902fc86f..d6e7fa2aa 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -11,7 +11,7 @@ """ResidualVectorQuantizer: multi-layer residual VQ with gradient training.""" -from typing import Dict, List, Optional, Sequence, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -92,12 +92,6 @@ class ResidualVectorQuantizer(ResidualQuantizer): normalize_residuals (bool): L2-normalize residuals before each quantization layer. Default: False. distance_type (str): distance metric, 'l2' or 'cosine'. Default: 'l2'. - commitment_loss (str): commitment loss type, 'l2', 'l1' or 'cos'. - Default: 'l2'. - latent_weight (List[float]): commitment loss weights [w1, w2]. - w1: x toward quant (encoder side). - w2: quant toward x (codebook side). - Default: [1.0, 0.5]. rotation_trick (bool): use rotation trick for improved STE gradient estimation (arXiv:2410.06424). Default: False. kmeans_init (bool): use residual K-Means codebook initialization @@ -121,8 +115,6 @@ def __init__( forward_mode: str = "ste", normalize_residuals: bool = False, distance_type: str = "l2", - commitment_loss: str = "l2", - latent_weight: Sequence[float] = (1.0, 0.5), rotation_trick: bool = False, kmeans_init: bool = False, use_sinkhorn: bool = True, @@ -131,19 +123,8 @@ def __init__( gumbel_temperature: float = 1.0, ) -> None: super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) - assert commitment_loss in ("l2", "l1", "cos"), ( - f"commitment_loss must be 'l2', 'l1' or 'cos', got {commitment_loss!r}" - ) - self.commitment_loss_type = commitment_loss self.rotation_trick = rotation_trick - if len(latent_weight) != 2: - raise ValueError( - f"latent_weight must have exactly 2 values [w1, w2], got " - f"{list(latent_weight)}" - ) - self.commitment_w1, self.commitment_w2 = latent_weight - # ``initted`` is the kmeans_init guard: True means "codebook has # been seeded", so init_embed_() becomes a no-op on later forwards. self.register_buffer("initted", torch.tensor([not kmeans_init])) @@ -186,17 +167,14 @@ def __init__( logger.info( "ResidualVectorQuantizer init: embed_dim=%d, n_layers=%d, " "n_embed=%s, forward_mode=%s, normalize_residuals=%s, " - "distance_type=%s, commitment_loss=%s, latent_weight=%s, " - "rotation_trick=%s, kmeans_init=%s, use_sinkhorn=%s, " - "sinkhorn_iters=%d, sinkhorn_epsilon=%s", + "distance_type=%s, rotation_trick=%s, kmeans_init=%s, " + "use_sinkhorn=%s, sinkhorn_iters=%d, sinkhorn_epsilon=%s", embed_dim, n_layers, n_embed, forward_mode, normalize_residuals, distance_type, - commitment_loss, - list(latent_weight), rotation_trick, kmeans_init, use_sinkhorn, @@ -257,44 +235,6 @@ def init_embed_(self, data: torch.Tensor) -> None: self.initted.fill_(True) - def _single_commitment_loss( - self, - x: torch.Tensor, - quant: torch.Tensor, - ) -> torch.Tensor: - """Commitment loss for a single cumulative quantization tensor. - - - cos: (1 - cosine_similarity) * weight - - l2: (x - quant)^2.mean() * weight - - l1: |x - quant|.mean() * weight - - Both directions are always summed: - loss1 = encoder-toward-quant (gradient flows into encoder) - loss2 = quant-toward-encoder (gradient flows into codebook) - - Args: - x (Tensor): original input, shape (B, D). - quant (Tensor): cumulative quantized output at one layer, - shape (B, D). - - Returns: - Tensor: scalar commitment loss for this layer. - """ - if self.commitment_loss_type == "cos": - loss1 = ( - 1 - F.cosine_similarity(x, quant.detach(), dim=-1) - ).mean() * self.commitment_w1 - loss2 = ( - 1 - F.cosine_similarity(x.detach(), quant, dim=-1) - ).mean() * self.commitment_w2 - elif self.commitment_loss_type == "l1": - loss1 = (x - quant.detach()).abs().mean() * self.commitment_w1 - loss2 = (x.detach() - quant).abs().mean() * self.commitment_w2 - else: # 'l2' - loss1 = (x - quant.detach()).pow(2.0).mean() * self.commitment_w1 - loss2 = (x.detach() - quant).pow(2.0).mean() * self.commitment_w2 - return loss1 + loss2 - @staticmethod def _apply_rotation_trick( x: torch.Tensor, @@ -383,7 +323,7 @@ def forward( Returns: ResidualQuantizerOutput: (cluster_ids, quantized_embeddings, - quantization_loss). + latents). """ if self.training: self.init_embed_(input) # first training forward only @@ -396,9 +336,10 @@ def forward( walk_input = input if train_gumbel else input.detach() cluster_ids, aggregated_quants, cumulative = self._residual_pass(walk_input) - commitment_loss = torch.mean( - torch.stack([self._single_commitment_loss(input, c) for c in cumulative]) - ) + # Expose the per-layer cumulative quantized vectors (grad-carrying on the + # codebook side) so the model-side CommitmentLoss can consume them; the + # commitment loss is no longer computed inside the quantizer. + latents = torch.stack(cumulative, dim=1) # (B, n_layers, D) # Aggregate STE (STE only; Gumbel already carries grad). quants_trunc = aggregated_quants @@ -411,7 +352,7 @@ def forward( return ResidualQuantizerOutput( cluster_ids=cluster_ids, quantized_embeddings=quants_trunc, - quantization_loss=commitment_loss, + latents=latents, ) @torch.no_grad() diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index 8fec055ca..4aee34322 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -181,7 +181,9 @@ def test_forward_output(self) -> None: self.assertIsInstance(out, ResidualQuantizerOutput) self.assertEqual(out.cluster_ids.shape, (5, 3)) self.assertEqual(out.quantized_embeddings.shape, (5, 8)) - self.assertTrue(torch.isfinite(out.quantization_loss).all()) + # latents: per-layer cumulative quantized vectors (B, n_layers, D). + self.assertEqual(out.latents.shape, (5, 3, 8)) + self.assertTrue(torch.isfinite(out.latents).all()) def test_decode_codes_shared_base(self) -> None: codes = torch.randint(0, 16, (5, 3)) @@ -220,61 +222,6 @@ def test_faiss_kmeans_init_seeds_codebook(self) -> None: self.assertGreater(layer.embedding.weight.abs().sum().item(), 0.0) -class CommitmentLossTest(unittest.TestCase): - """ResidualVectorQuantizer commitment-loss branches (l1 / cos / invalid).""" - - def test_commitment_loss_l1_branch(self) -> None: - """The commitment_loss='l1' branch runs end-to-end (no fall-through to l2).""" - torch.manual_seed(0) - rq = ResidualVectorQuantizer( - embed_dim=8, - n_layers=2, - n_embed=4, - forward_mode="ste", - commitment_loss="l1", - kmeans_init=False, - use_sinkhorn=False, - ) - for layer in rq.layers: - torch.nn.init.normal_(layer.embedding.weight, std=0.1) - x = torch.randn(4, 8, requires_grad=True) - out = rq(x) - self.assertTrue(torch.isfinite(out.quantization_loss)) - out.quantization_loss.backward() - self.assertIsNotNone(x.grad) - - def test_commitment_loss_cos_branch(self) -> None: - """The commitment_loss='cos' branch runs end-to-end.""" - torch.manual_seed(0) - rq = ResidualVectorQuantizer( - embed_dim=8, - n_layers=2, - n_embed=4, - forward_mode="ste", - commitment_loss="cos", - kmeans_init=False, - use_sinkhorn=False, - ) - for layer in rq.layers: - torch.nn.init.normal_(layer.embedding.weight, std=0.1) - x = torch.randn(4, 8, requires_grad=True) - out = rq(x) - self.assertTrue(torch.isfinite(out.quantization_loss)) - out.quantization_loss.backward() - self.assertIsNotNone(x.grad) - - def test_commitment_loss_invalid_raises(self) -> None: - """ResidualVectorQuantizer rejects unknown commitment_loss spellings.""" - with self.assertRaisesRegex(AssertionError, "commitment_loss"): - ResidualVectorQuantizer( - embed_dim=8, - n_layers=2, - n_embed=4, - commitment_loss="bogus", - use_sinkhorn=False, - ) - - # --- Multi-process test for ResidualVectorQuantizer FAISS kmeans-init. --- # Validates the DDP path of ``init_embed_``: the codebook is fit on rank 0 only # and broadcast, so every rank ends with a bit-identical warm start. (The diff --git a/tzrec/modules/sid/types.py b/tzrec/modules/sid/types.py index ccb5caffe..7b110f023 100644 --- a/tzrec/modules/sid/types.py +++ b/tzrec/modules/sid/types.py @@ -44,12 +44,18 @@ class QuantizeOutput(NamedTuple): class ResidualQuantizerOutput(NamedTuple): """Output of the residual quantization module (RQ-VAE backend). + The commitment loss is no longer computed inside the quantizer; the per-layer + cumulative quantized vectors are exposed as ``latents`` so the model-side + commitment loss (:class:`~tzrec.loss.commitment_loss.CommitmentLoss`) can + consume them. + Attributes: cluster_ids (Tensor): codebook indices per layer, shape (B, n_layers). quantized_embeddings (Tensor): sum of quantized embeddings, shape (B, D). - quantization_loss (Tensor): total commitment loss scalar. + latents (Tensor): per-layer cumulative quantized vectors, shape + (B, n_layers, D) (``latents[:, i]`` is the sum after layer ``i``). """ cluster_ids: torch.Tensor quantized_embeddings: torch.Tensor - quantization_loss: torch.Tensor + latents: torch.Tensor diff --git a/tzrec/protos/loss.proto b/tzrec/protos/loss.proto index 6468cc60e..466479cb2 100644 --- a/tzrec/protos/loss.proto +++ b/tzrec/protos/loss.proto @@ -9,6 +9,44 @@ message LossConfig { JRCLoss jrc_loss = 4; BinaryFocalLoss binary_focal_loss = 5; } + // Losses for semantic-ID (SID) generation models (SidRqvae). A SID model + // lists one LossConfig per term it trains on (a reconstruction loss, the + // commitment loss, and optionally the CLIP contrastive loss). + oneof sid_loss { + ReconL2Loss recon_l2_loss = 6; + ReconL1Loss recon_l1_loss = 7; + ReconCosineLoss recon_cosine_loss = 8; + CommitmentLoss commitment_loss = 9; + SidClipLoss sid_clip_loss = 10; + } +} + +// RQ-VAE reconstruction losses (input vs. decoder output). +message ReconL2Loss { +} + +message ReconL1Loss { +} + +message ReconCosineLoss { +} + +// RQ-VAE commitment loss between the encoder output and the per-layer +// cumulative quantized vectors. +message CommitmentLoss { + // Commitment loss weights [w1, w2] (encoder-toward-quant, quant-toward- + // encoder). Defaults to [1.0, 0.5] when unset. + repeated float latent_weight = 1; + // Distance used for the commitment term: "l2", "l1" or "cos". + optional string commitment_type = 2 [default = "l2"]; +} + +// CLIP contrastive loss for the dual (image/text) reconstruction path. +message SidClipLoss { + // Name of the second (paired) embedding feature inside the input Batch. + required string clip_feature_name = 1; + // Name of the per-row float feature flagging CLIP-pair rows (>0.5 = pair). + required string is_clip_pair_feature_name = 2; } message BinaryCrossEntropy { diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index fa17321f8..4eb3ffffd 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -26,19 +26,6 @@ message SinkhornConfig { optional bool enabled = 3 [default = true]; } -message ClipConfig { - // Name of the second feature (paired with embedding_feature_name - // to form a contrastive pair). - required string clip_feature_name = 1; - // Name of the per-sample boolean feature (0/1, value_dim=1) that - // flags whether the row is a CLIP pair (1) or a reconstruction-only - // row (0). Required for mixed recon+clip batches: the model uses - // this column directly as the ``clip_mask``. Replaces the prior - // bit-exact ``embedding == fea2`` discrimination, which silently - // mislabeled rows on any upstream float cast / normalization. - required string is_clip_pair_feature_name = 2; -} - message SidRqvae { // === Network structure === // Input embedding dimension. @@ -60,11 +47,6 @@ message SidRqvae { optional bool normalize_residuals = 7 [default = false]; // Distance metric: "l2" or "cosine". optional string distance_type = 9 [default = "l2"]; - // Commitment loss type: "l2", "l1" or "cos". - optional string commitment_loss = 10 [default = "l2"]; - // Commitment loss weights [w1, w2]. Defaults to [1.0, 0.5] when unset - // (applied by SidRqvae / ResidualVectorQuantizer). - repeated float latent_weight = 11; // STE rotation trick. optional bool rotation_trick = 12 [default = false]; // KMeans codebook initialization on first training forward. Default false. @@ -85,11 +67,10 @@ message SidRqvae { // omitted: enabled with iters=5, epsilon=10.0. Include the sub-config // to override params; set ``enabled: false`` inside it to disable. optional SinkhornConfig sinkhorn_config = 15; - // CLIP contrastive learning (disabled when unset). - optional ClipConfig clip_config = 16; - // Reconstruction loss type: "mse", "l1", or "cosine". - optional string loss_type = 20 [default = "mse"]; + // Reconstruction, commitment and (optional) CLIP losses are configured via + // ModelConfig.losses (the LossConfig ``sid_loss`` oneof), not on this + // message — see tzrec/protos/loss.proto. // Name of the item embedding feature inside the input Batch. optional string embedding_feature_name = 40 [default = "item_emb"]; diff --git a/tzrec/tests/configs/sid_rqvae_mock.config b/tzrec/tests/configs/sid_rqvae_mock.config index 1446ff872..f94179621 100644 --- a/tzrec/tests/configs/sid_rqvae_mock.config +++ b/tzrec/tests/configs/sid_rqvae_mock.config @@ -48,8 +48,17 @@ model_config { codebook: 16 codebook: 16 forward_mode: "ste" - loss_type: "mse" kmeans_init: false embedding_feature_name: "item_emb" } + losses { + recon_l2_loss { + } + } + losses { + commitment_loss { + latent_weight: 1.0 + latent_weight: 0.5 + } + } } From 8f0d8824e96c6db6c1177846c3b321842501371d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 03:23:39 +0000 Subject: [PATCH 104/129] [refactor] SID: /simplify cleanups on the config-driven loss refactor - _recon_loss: use div_no_nan for the masked mean (house style) - predict: hoist the inference short-circuit to one get_codes guard (no decoder output / commitment latents built at inference) - dedup CLIP logit_scale init + clamp/exp (constant + local helper) - slim _RECON_TYPES dict -> _RECON_LOSSES frozenset; _recon_loss branches on the sid_loss variant names directly - clip-config lookup via next() Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 53 ++++++++++++++++++++------------------- tzrec/models/sid_rqvae.py | 32 ++++++++++++----------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 475994bf6..0bb1a1799 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -26,19 +26,18 @@ from tzrec.metrics.relative_l1 import RelativeL1 from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel +from tzrec.modules.utils import div_no_nan from tzrec.protos.loss_pb2 import LossConfig from tzrec.protos.model_pb2 import ModelConfig # Cap the CLIP temperatures before ``exp`` (reference CLIP clamps to ln(100)): # an unbounded ``logit_scale`` overflows to +Inf -> NaN grad -> corrupt param. _LOGIT_SCALE_MAX = float(np.log(100)) +# CLIP temperature init (reference CLIP: log(1 / 0.07)). +_LOGIT_SCALE_INIT = float(np.log(1 / 0.07)) -# sid_loss reconstruction variants -> the reduction used by ``_recon_loss``. -_RECON_TYPES = { - "recon_l2_loss": "mse", - "recon_l1_loss": "l1", - "recon_cosine_loss": "cosine", -} +# sid_loss reconstruction variants (``_recon_loss`` branches on these directly). +_RECON_LOSSES = frozenset(("recon_l2_loss", "recon_l1_loss", "recon_cosine_loss")) class BaseSidModel(BaseModel): @@ -129,7 +128,7 @@ def init_loss(self) -> None: def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: """Register the module (if any) for one ``sid_loss`` config.""" loss_type = loss_cfg.WhichOneof("sid_loss") - if loss_type in _RECON_TYPES: + if loss_type in _RECON_LOSSES: return # reconstruction losses are functional (no module) elif loss_type == "commitment_loss": cfg = loss_cfg.commitment_loss @@ -140,9 +139,9 @@ def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: ) elif loss_type == "sid_clip_loss": # The three learnable CLIP temperatures + the masked-CLIP module. - self._logit_scale_self = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._logit_scale_cl = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - self._logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + self._logit_scale_self = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) + self._logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) + self._logit_scale = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) self._loss_modules["sid_clip_loss"] = MaskedCLIPLoss() else: raise ValueError( @@ -174,11 +173,11 @@ def _sid_loss_impl( ) -> Dict[str, torch.Tensor]: """Compute one ``sid_loss`` term from ``predictions``.""" loss_type = loss_cfg.WhichOneof("sid_loss") - if loss_type in _RECON_TYPES: + if loss_type in _RECON_LOSSES: loss = self._recon_loss( predictions["x_hat"], predictions["recon_target"], - _RECON_TYPES[loss_type], + loss_type, predictions.get("recon_mask"), ) return {loss_type: loss} @@ -188,18 +187,19 @@ def _sid_loss_impl( ) return {"commitment_loss": loss} elif loss_type == "sid_clip_loss": + + def scaled(p: torch.Tensor) -> torch.Tensor: + # clamp before exp so a large temperature can't overflow to +Inf. + return p.clamp(max=_LOGIT_SCALE_MAX).exp() + feats = { "image_embed": predictions["clip_image"], "text_embed": predictions["clip_text"], "image_embed_ori": predictions["clip_image_ori"], "text_embed_ori": predictions["clip_text_ori"], - "logit_scale_self": self._logit_scale_self.clamp( - max=_LOGIT_SCALE_MAX - ).exp(), - "logit_scale_cl": self._logit_scale_cl.clamp( - max=_LOGIT_SCALE_MAX - ).exp(), - "logit_scale": self._logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp(), + "logit_scale_self": scaled(self._logit_scale_self), + "logit_scale_cl": scaled(self._logit_scale_cl), + "logit_scale": scaled(self._logit_scale), } out = self._loss_modules["sid_clip_loss"](feats, predictions["clip_mask"]) return {"sid_clip_loss": out["clip_loss"]} @@ -210,10 +210,10 @@ def _recon_loss( self, x_hat: torch.Tensor, x: torch.Tensor, - recon_type: str, + recon_loss: str, mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - """Reconstruction loss for ``recon_type`` ('mse'|'l1'|'cosine'). + """Reconstruction loss for a ``sid_loss`` recon variant. Returns the mean over all rows, or — when ``mask`` (a per-row bool) is given — the mean over only the masked-in rows (the mixed recon+CLIP path @@ -223,19 +223,20 @@ def _recon_loss( Args: x_hat (Tensor): reconstructed output, shape (B, D). x (Tensor): original input, shape (B, D). - recon_type (str): 'mse', 'l1' or 'cosine'. + recon_loss (str): the recon variant, one of ``_RECON_LOSSES`` + (``recon_l2_loss`` | ``recon_l1_loss`` | ``recon_cosine_loss``). mask (Tensor, optional): per-row bool; rows to include. """ - if recon_type == "mse": + if recon_loss == "recon_l2_loss": per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) - elif recon_type == "l1": + elif recon_loss == "recon_l1_loss": per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - else: # 'cosine' + else: # "recon_cosine_loss" per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) if mask is None: return per_sample.mean() mask = mask.float() - return (per_sample * mask).sum() / mask.sum().clamp(min=1) + return div_no_nan((per_sample * mask).sum(), mask.sum()) def init_metric(self) -> None: """Initialize the eval metrics shared by all SID models. diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 4c9069a1f..c114cbf25 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -78,15 +78,19 @@ def __init__( # CLIP is enabled by a `sid_clip_loss` entry in ModelConfig.losses, which # also carries the paired-feature names (data wiring). - self._clip_feature_name: Optional[str] = None - self._is_clip_pair_feature_name: Optional[str] = None - for loss_cfg in self._base_model_config.losses: - if loss_cfg.WhichOneof("sid_loss") == "sid_clip_loss": - self._clip_feature_name = loss_cfg.sid_clip_loss.clip_feature_name - self._is_clip_pair_feature_name = ( - loss_cfg.sid_clip_loss.is_clip_pair_feature_name - ) - self._use_clip = self._clip_feature_name is not None + clip_cfg = next( + ( + lc.sid_clip_loss + for lc in self._base_model_config.losses + if lc.WhichOneof("sid_loss") == "sid_clip_loss" + ), + None, + ) + self._use_clip = clip_cfg is not None + self._clip_feature_name = clip_cfg.clip_feature_name if clip_cfg else None + self._is_clip_pair_feature_name = ( + clip_cfg.is_clip_pair_feature_name if clip_cfg else None + ) embed_dim = cfg.embed_dim # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero @@ -156,6 +160,10 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: predictions (dict): a dict of predicted result. """ embedding = self._extract_feature(batch) + if self._is_inference: + # Codes-only path: get_codes does just the residual walk (no decode, + # no commitment latents), so neither dual-path branch is needed. + return {"codes": self._quantizer.get_codes(self._encode(embedding))} if self._use_clip: return self._predict_mixed(embedding, batch) return self._predict_rqvae(embedding) @@ -164,8 +172,6 @@ def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: """Standard RQ-VAE: encode -> quantize -> decode.""" z_e = self._encode(embedding) quant = self._quantizer(z_e) - if self._is_inference: - return {"codes": quant.cluster_ids} return { "codes": quant.cluster_ids, "x_hat": self._decode(quant.quantized_embeddings), @@ -183,10 +189,6 @@ def _predict_mixed( averages over them; ``recon_mask`` (= non-CLIP rows) restricts the recon loss to reconstruction-only rows. """ - if self._is_inference: - z_e = self._encode(embedding) - return {"codes": self._quantizer(z_e).cluster_ids} - fea2 = self._extract_feature(batch, self._clip_feature_name) is_clip_pair_raw = self._extract_feature(batch, self._is_clip_pair_feature_name) clip_mask = is_clip_pair_raw.view(is_clip_pair_raw.shape[0], -1)[:, 0] > 0.5 From 5a7370ffeafbbf1834357c5a812daca6dc20faf6 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 06:11:18 +0000 Subject: [PATCH 105/129] [refactor] SID: generalize CLIP loss -> masked InfoNCE; model owns structure (E1) Round-5 review E1. Accept the rename; the "gather valid pairs outside the loss" half is declined (verified bit-exact no-op that breaks DDP collective-safety). - Rename MaskedCLIPLoss -> MaskedInfoNCELoss (clip_loss.py -> infonce_loss.py); generalize the modality-tied interface: image_embed/text_embed/*_ori -> embed_a/embed_b/*_ori, clip_mask -> pair_mask, output "clip_loss" -> "loss". Behavior-preserving (same self/ori/cl three-term masked InfoNCE). - Model owns structure (stop deriving forward topology from the loss list): the CLIP dual-encoder wiring (clip_feature_name, is_clip_pair_feature_name) moves from the SidClipLoss proto onto a new SidRqvae.clip_config field; SidClipLoss is now an empty objective marker. SidRqvae._use_clip reads clip_config, and a consistency check requires clip_config and a sid_clip_loss entry together. All 109 SID tests green. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/clip_loss.py | 187 ----------------- tzrec/loss/infonce_loss.py | 188 ++++++++++++++++++ ...clip_loss_test.py => infonce_loss_test.py} | 94 ++++----- tzrec/models/sid_model.py | 18 +- tzrec/models/sid_rqvae.py | 41 ++-- tzrec/models/sid_rqvae_test.py | 10 +- tzrec/protos/loss.proto | 8 +- tzrec/protos/models/sid_model.proto | 21 +- 8 files changed, 296 insertions(+), 271 deletions(-) delete mode 100644 tzrec/loss/clip_loss.py create mode 100644 tzrec/loss/infonce_loss.py rename tzrec/loss/{clip_loss_test.py => infonce_loss_test.py} (76%) diff --git a/tzrec/loss/clip_loss.py b/tzrec/loss/clip_loss.py deleted file mode 100644 index 493ecf006..000000000 --- a/tzrec/loss/clip_loss.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""CLIP contrastive learning loss with distributed all-gather support.""" - -from typing import Dict, List, Optional - -import torch -import torch.distributed as dist -import torch.distributed.nn as dist_nn -from torch.nn import functional as F -from torch.nn.modules.loss import _Loss - - -class MaskedCLIPLoss(_Loss): - """Masked CLIP loss for mixed recon+clip batches. - - In a mixed batch, recon rows (clip_mask=False) must not contribute to the - CLIP loss, and recon columns must not serve as negatives. Row/column masks - achieve this without data-dependent branching (``torch.compile``-friendly). - - Input dict keys (all embeddings shape (B, input_dim)): - 'image_embed': reconstructed (decoder) output of feature 1 - 'text_embed': reconstructed (decoder) output of feature 2 - 'image_embed_ori': original embedding of feature 1 - 'text_embed_ori': original embedding of feature 2 - 'logit_scale_self': scalar temperature: recon-1 vs recon-2 - 'logit_scale_cl': scalar temperature: recon vs same-feature original - 'logit_scale': scalar temperature: recon vs counterpart original - - Output dict keys: - 'clip_loss': scalar mean of three contrastive losses (self/ori/cl) - """ - - def __init__(self) -> None: - super().__init__() - self.labels: Optional[torch.Tensor] = None - self.last_local_batch_size: Optional[int] = None - self._rank = dist.get_rank() if dist.is_initialized() else 0 - - @staticmethod - def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]: - """All-gather tensors across workers with gradient support. - - Single-process: returns the inputs unchanged. Multi-process: uses the - built-in differentiable ``torch.distributed.nn.functional.all_gather``, - so no custom ``autograd.Function`` is needed. - - Args: - tensors (List[Tensor]): list of tensors to gather. - - Returns: - List[Tensor]: gathered tensors, each (world_size * B, ...). - """ - if not dist.is_initialized() or dist.get_world_size() == 1: - return tensors - gathered: List[torch.Tensor] = [] - for tensor in tensors: - tensor_all = dist_nn.all_gather(tensor) # differentiable, per rank - gathered.append(torch.cat(tensor_all, dim=0)) - return gathered - - @staticmethod - def _gather_bool_mask(mask: torch.Tensor) -> torch.Tensor: - """All-gather bool mask across distributed workers.""" - if not dist.is_initialized() or dist.get_world_size() == 1: - return mask - mask_list = [torch.zeros_like(mask) for _ in range(dist.get_world_size())] - dist.all_gather(mask_list, mask) - return torch.cat(mask_list, dim=0) - - def _masked_cross_entropy( - self, - logits_i: torch.Tensor, - logits_t: torch.Tensor, - safe_labels: torch.Tensor, - clip_mask: torch.Tensor, - ) -> torch.Tensor: - """Masked cross-entropy on column-masked logits, row-masked average. - - Args: - logits_i: (B, B_global) column-masked logits (image branch). - logits_t: (B, B_global) column-masked logits (text branch). - safe_labels: (B,) labels with recon rows fallback to safe col. - clip_mask: (B,) bool, True = clip row. - """ - ce_i = F.cross_entropy(logits_i, safe_labels, reduction="none") - ce_t = F.cross_entropy(logits_t, safe_labels, reduction="none") - # Backstop against a non-finite upstream logit (e.g. overflowed scale). - ce_i = torch.nan_to_num(ce_i, nan=0.0) - ce_t = torch.nan_to_num(ce_t, nan=0.0) - - # Only clip rows contribute; clamp(min=1) keeps a no-clip batch at 0. - n_valid = clip_mask.float().sum().clamp(min=1) - return ((ce_i + ce_t) * clip_mask.float()).sum() / (2 * n_valid) - - def forward( - self, - outputs: Dict[str, torch.Tensor], - clip_mask: torch.Tensor, - ) -> Dict[str, torch.Tensor]: - """Forward with mask. - - Args: - outputs: feature dict, see class docstring. - clip_mask: (B,) bool, True = clip sample. - """ - image_embed = outputs["image_embed"] - text_embed = outputs["text_embed"] - image_embed_ori = outputs["image_embed_ori"] - text_embed_ori = outputs["text_embed_ori"] - logit_scale = outputs["logit_scale"] - logit_scale_self = outputs["logit_scale_self"] - logit_scale_cl = outputs["logit_scale_cl"] - - local_batch_size = image_embed.size(0) - - # Update labels when batch size changes (multi-GPU offset) - if local_batch_size != self.last_local_batch_size: - self.labels = local_batch_size * self._rank + torch.arange( - local_batch_size, device=image_embed.device - ) - self.last_local_batch_size = local_batch_size - - # L2 normalize quantized features - image_embed = F.normalize(image_embed, dim=-1, p=2) - text_embed = F.normalize(text_embed, dim=-1, p=2) - - # All-gather across GPUs (with gradient support) - image_embed_all, text_embed_all = self._all_gather_with_grad( - [image_embed, text_embed] - ) - image_embed_all_ori, text_embed_all_ori = self._all_gather_with_grad( - [image_embed_ori, text_embed_ori] - ) - - # --- Compute six groups of logits (image/text × self/ori/cl) --- - logits_img_self = logit_scale_self * image_embed @ text_embed_all.t() - logits_txt_self = logit_scale_self * text_embed @ image_embed_all.t() - - logits_img_ori = logit_scale * image_embed @ text_embed_all_ori.t() - logits_txt_ori = logit_scale * text_embed @ image_embed_all_ori.t() - - logits_img_cl = logit_scale_cl * image_embed @ image_embed_all_ori.t() - logits_txt_cl = logit_scale_cl * text_embed @ text_embed_all_ori.t() - - # Mask recon columns out of the negatives with the dtype's most negative - # finite value: below any real logit (masks like -inf), but finite so an - # all-recon row gives a finite CE/grad instead of 0*NaN. - clip_mask_all = self._gather_bool_mask(clip_mask) - col_mask = (~clip_mask_all).unsqueeze(0) # (1, B_global) - neg_fill = torch.finfo(logits_img_self.dtype).min - - logits_img_self = logits_img_self.masked_fill(col_mask, neg_fill) - logits_txt_self = logits_txt_self.masked_fill(col_mask, neg_fill) - logits_img_ori = logits_img_ori.masked_fill(col_mask, neg_fill) - logits_txt_ori = logits_txt_ori.masked_fill(col_mask, neg_fill) - logits_img_cl = logits_img_cl.masked_fill(col_mask, neg_fill) - logits_txt_cl = logits_txt_cl.masked_fill(col_mask, neg_fill) - - # --- Safe labels: recon rows fallback to first clip column --- - labels = self.labels - fallback = clip_mask.long().argmax() # first clip sample index - safe_labels = torch.where(clip_mask, labels, fallback.expand_as(labels)) - - # --- Masked CE for three loss groups --- - loss_self = self._masked_cross_entropy( - logits_img_self, logits_txt_self, safe_labels, clip_mask - ) - loss_ori = self._masked_cross_entropy( - logits_img_ori, logits_txt_ori, safe_labels, clip_mask - ) - loss_cl = self._masked_cross_entropy( - logits_img_cl, logits_txt_cl, safe_labels, clip_mask - ) - - clip_loss = (loss_self + loss_ori + loss_cl) / 3 - - return {"clip_loss": clip_loss} diff --git a/tzrec/loss/infonce_loss.py b/tzrec/loss/infonce_loss.py new file mode 100644 index 000000000..703c737fd --- /dev/null +++ b/tzrec/loss/infonce_loss.py @@ -0,0 +1,188 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Masked InfoNCE contrastive loss with distributed all-gather support.""" + +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.distributed.nn as dist_nn +from torch.nn import functional as F +from torch.nn.modules.loss import _Loss + + +class MaskedInfoNCELoss(_Loss): + """Masked InfoNCE contrastive loss for mixed (paired + non-paired) batches. + + Modality-agnostic: aligns two reconstructed "views" (``embed_a``/``embed_b``) + against each other and against their originals (``embed_a_ori``/ + ``embed_b_ori``) with three symmetric InfoNCE terms (self/ori/cl). In a mixed + batch, non-pair rows (``pair_mask=False``) must not contribute and must not + serve as negatives; row/column masks achieve this without data-dependent + branching (``torch.compile``-friendly). + + Input dict keys (all embeddings shape (B, dim)): + 'embed_a': reconstructed (decoder) output of view a + 'embed_b': reconstructed (decoder) output of view b + 'embed_a_ori': original embedding of view a + 'embed_b_ori': original embedding of view b + 'logit_scale_self': scalar temperature: recon-a vs recon-b + 'logit_scale_cl': scalar temperature: recon vs same-view original + 'logit_scale': scalar temperature: recon vs counterpart original + + Output dict keys: + 'loss': scalar mean of the three contrastive losses (self/ori/cl) + """ + + def __init__(self) -> None: + super().__init__() + self.labels: Optional[torch.Tensor] = None + self.last_local_batch_size: Optional[int] = None + self._rank = dist.get_rank() if dist.is_initialized() else 0 + + @staticmethod + def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """All-gather tensors across workers with gradient support. + + Single-process: returns the inputs unchanged. Multi-process: uses the + built-in differentiable ``torch.distributed.nn.functional.all_gather``, + so no custom ``autograd.Function`` is needed. + + Args: + tensors (List[Tensor]): list of tensors to gather. + + Returns: + List[Tensor]: gathered tensors, each (world_size * B, ...). + """ + if not dist.is_initialized() or dist.get_world_size() == 1: + return tensors + gathered: List[torch.Tensor] = [] + for tensor in tensors: + tensor_all = dist_nn.all_gather(tensor) # differentiable, per rank + gathered.append(torch.cat(tensor_all, dim=0)) + return gathered + + @staticmethod + def _gather_bool_mask(mask: torch.Tensor) -> torch.Tensor: + """All-gather bool mask across distributed workers.""" + if not dist.is_initialized() or dist.get_world_size() == 1: + return mask + mask_list = [torch.zeros_like(mask) for _ in range(dist.get_world_size())] + dist.all_gather(mask_list, mask) + return torch.cat(mask_list, dim=0) + + def _masked_cross_entropy( + self, + logits_a: torch.Tensor, + logits_b: torch.Tensor, + safe_labels: torch.Tensor, + pair_mask: torch.Tensor, + ) -> torch.Tensor: + """Masked cross-entropy on column-masked logits, row-masked average. + + Args: + logits_a: (B, B_global) column-masked logits (view-a branch). + logits_b: (B, B_global) column-masked logits (view-b branch). + safe_labels: (B,) labels with non-pair rows fallback to a safe col. + pair_mask: (B,) bool, True = pair row. + """ + ce_a = F.cross_entropy(logits_a, safe_labels, reduction="none") + ce_b = F.cross_entropy(logits_b, safe_labels, reduction="none") + # Backstop against a non-finite upstream logit (e.g. overflowed scale). + ce_a = torch.nan_to_num(ce_a, nan=0.0) + ce_b = torch.nan_to_num(ce_b, nan=0.0) + + # Only pair rows contribute; clamp(min=1) keeps a no-pair batch at 0. + n_valid = pair_mask.float().sum().clamp(min=1) + return ((ce_a + ce_b) * pair_mask.float()).sum() / (2 * n_valid) + + def forward( + self, + outputs: Dict[str, torch.Tensor], + pair_mask: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + """Forward with the pair mask. + + Args: + outputs: feature dict, see class docstring. + pair_mask: (B,) bool, True = contrastive-pair sample. + """ + embed_a = outputs["embed_a"] + embed_b = outputs["embed_b"] + embed_a_ori = outputs["embed_a_ori"] + embed_b_ori = outputs["embed_b_ori"] + logit_scale = outputs["logit_scale"] + logit_scale_self = outputs["logit_scale_self"] + logit_scale_cl = outputs["logit_scale_cl"] + + local_batch_size = embed_a.size(0) + + # Update labels when batch size changes (multi-GPU offset) + if local_batch_size != self.last_local_batch_size: + self.labels = local_batch_size * self._rank + torch.arange( + local_batch_size, device=embed_a.device + ) + self.last_local_batch_size = local_batch_size + + # L2 normalize the reconstructed features + embed_a = F.normalize(embed_a, dim=-1, p=2) + embed_b = F.normalize(embed_b, dim=-1, p=2) + + # All-gather across GPUs (with gradient support) + embed_a_all, embed_b_all = self._all_gather_with_grad([embed_a, embed_b]) + embed_a_all_ori, embed_b_all_ori = self._all_gather_with_grad( + [embed_a_ori, embed_b_ori] + ) + + # --- Compute six groups of logits (a/b × self/ori/cl) --- + logits_a_self = logit_scale_self * embed_a @ embed_b_all.t() + logits_b_self = logit_scale_self * embed_b @ embed_a_all.t() + + logits_a_ori = logit_scale * embed_a @ embed_b_all_ori.t() + logits_b_ori = logit_scale * embed_b @ embed_a_all_ori.t() + + logits_a_cl = logit_scale_cl * embed_a @ embed_a_all_ori.t() + logits_b_cl = logit_scale_cl * embed_b @ embed_b_all_ori.t() + + # Mask non-pair columns out of the negatives with the dtype's most negative + # finite value: below any real logit (masks like -inf), but finite so an + # all-non-pair row gives a finite CE/grad instead of 0*NaN. + pair_mask_all = self._gather_bool_mask(pair_mask) + col_mask = (~pair_mask_all).unsqueeze(0) # (1, B_global) + neg_fill = torch.finfo(logits_a_self.dtype).min + + logits_a_self = logits_a_self.masked_fill(col_mask, neg_fill) + logits_b_self = logits_b_self.masked_fill(col_mask, neg_fill) + logits_a_ori = logits_a_ori.masked_fill(col_mask, neg_fill) + logits_b_ori = logits_b_ori.masked_fill(col_mask, neg_fill) + logits_a_cl = logits_a_cl.masked_fill(col_mask, neg_fill) + logits_b_cl = logits_b_cl.masked_fill(col_mask, neg_fill) + + # --- Safe labels: non-pair rows fallback to the first pair column --- + labels = self.labels + fallback = pair_mask.long().argmax() # first pair sample index + safe_labels = torch.where(pair_mask, labels, fallback.expand_as(labels)) + + # --- Masked CE for three loss groups --- + loss_self = self._masked_cross_entropy( + logits_a_self, logits_b_self, safe_labels, pair_mask + ) + loss_ori = self._masked_cross_entropy( + logits_a_ori, logits_b_ori, safe_labels, pair_mask + ) + loss_cl = self._masked_cross_entropy( + logits_a_cl, logits_b_cl, safe_labels, pair_mask + ) + + loss = (loss_self + loss_ori + loss_cl) / 3 + + return {"loss": loss} diff --git a/tzrec/loss/clip_loss_test.py b/tzrec/loss/infonce_loss_test.py similarity index 76% rename from tzrec/loss/clip_loss_test.py rename to tzrec/loss/infonce_loss_test.py index 4436f6b33..a395d1845 100644 --- a/tzrec/loss/clip_loss_test.py +++ b/tzrec/loss/infonce_loss_test.py @@ -17,82 +17,82 @@ import torch.distributed as dist import torch.multiprocessing as mp -from tzrec.loss.clip_loss import MaskedCLIPLoss +from tzrec.loss.infonce_loss import MaskedInfoNCELoss from tzrec.utils import misc_util class AllGatherWithGradTest(unittest.TestCase): def test_single_process_identity(self) -> None: a, b = torch.randn(3, 4), torch.randn(3, 4) - out = MaskedCLIPLoss._all_gather_with_grad([a, b]) + out = MaskedInfoNCELoss._all_gather_with_grad([a, b]) self.assertIs(out[0], a) self.assertIs(out[1], b) -class MaskedCLIPLossTest(unittest.TestCase): +class MaskedInfoNCELossTest(unittest.TestCase): """Single-process tests for the masked CLIP loss.""" def _features(self, B: int, D: int) -> dict: torch.manual_seed(0) scale = torch.tensor(np.log(1 / 0.07)).exp() return { - "image_embed": torch.randn(B, D, requires_grad=True), - "text_embed": torch.randn(B, D, requires_grad=True), - "image_embed_ori": torch.randn(B, D), - "text_embed_ori": torch.randn(B, D), + "embed_a": torch.randn(B, D, requires_grad=True), + "embed_b": torch.randn(B, D, requires_grad=True), + "embed_a_ori": torch.randn(B, D), + "embed_b_ori": torch.randn(B, D), "logit_scale_self": scale, "logit_scale_cl": scale, "logit_scale": scale, } def test_forward_all_clip_finite(self) -> None: - loss_fn = MaskedCLIPLoss() + loss_fn = MaskedInfoNCELoss() feats = self._features(6, 8) mask = torch.ones(6, dtype=torch.bool) out = loss_fn(feats, mask) - self.assertIn("clip_loss", out) - self.assertTrue(torch.isfinite(out["clip_loss"])) - self.assertGreater(out["clip_loss"].item(), 0.0) + self.assertIn("loss", out) + self.assertTrue(torch.isfinite(out["loss"])) + self.assertGreater(out["loss"].item(), 0.0) def test_all_recon_mask_zero_loss(self) -> None: - loss_fn = MaskedCLIPLoss() + loss_fn = MaskedInfoNCELoss() feats = self._features(6, 8) mask = torch.zeros(6, dtype=torch.bool) # no clip rows out = loss_fn(feats, mask) # No clip rows -> masked average is exactly zero (and finite). - self.assertTrue(torch.isfinite(out["clip_loss"])) - self.assertAlmostEqual(out["clip_loss"].item(), 0.0, places=6) + self.assertTrue(torch.isfinite(out["loss"])) + self.assertAlmostEqual(out["loss"].item(), 0.0, places=6) def test_all_recon_mask_finite_gradient(self) -> None: # Regression: with float("-inf") column fill an all-recon batch produced # a NaN gradient (0 * NaN) that survived the row mask. The finite fill # must keep the backward finite (and zero, since no clip row contributes). - loss_fn = MaskedCLIPLoss() + loss_fn = MaskedInfoNCELoss() feats = self._features(6, 8) mask = torch.zeros(6, dtype=torch.bool) - loss_fn(feats, mask)["clip_loss"].backward() - grad = feats["image_embed"].grad + loss_fn(feats, mask)["loss"].backward() + grad = feats["embed_a"].grad self.assertIsNotNone(grad) self.assertTrue(torch.isfinite(grad).all()) self.assertAlmostEqual(grad.abs().sum().item(), 0.0, places=6) def test_backward_flows_to_embeddings(self) -> None: - loss_fn = MaskedCLIPLoss() + loss_fn = MaskedInfoNCELoss() feats = self._features(6, 8) mask = torch.ones(6, dtype=torch.bool) - loss_fn(feats, mask)["clip_loss"].backward() - self.assertIsNotNone(feats["image_embed"].grad) - self.assertTrue(torch.isfinite(feats["image_embed"].grad).all()) + loss_fn(feats, mask)["loss"].backward() + self.assertIsNotNone(feats["embed_a"].grad) + self.assertTrue(torch.isfinite(feats["embed_a"].grad).all()) def test_recon_columns_excluded_from_negatives(self) -> None: """A recon row's embedding must not affect a clip row's loss. Recon rows are dropped as queries (row mask) AND their columns are masked out of the negatives (col_mask). Perturbing the recon rows of - EVERY column operand — ``text_embed`` (the self group) and both + EVERY column operand — ``embed_b`` (the self group) and both ``*_ori`` operands (the ori/cl groups) — must leave the clip rows' loss unchanged; a dropped or inverted ``col_mask`` on any group would fail. - Distinct ``image_embed_ori`` / ``text_embed_ori`` so the ori/cl masking + Distinct ``embed_a_ori`` / ``embed_b_ori`` so the ori/cl masking is actually exercised (not hidden by a shared tensor). """ torch.manual_seed(0) @@ -103,24 +103,24 @@ def test_recon_columns_excluded_from_negatives(self) -> None: def feats(txt: torch.Tensor, txt_ori: torch.Tensor, img_ori: torch.Tensor): return { - "image_embed": img, - "text_embed": txt, - "image_embed_ori": img_ori, - "text_embed_ori": txt_ori, + "embed_a": img, + "embed_b": txt, + "embed_a_ori": img_ori, + "embed_b_ori": txt_ori, "logit_scale_self": scale, "logit_scale_cl": scale, "logit_scale": scale, } txt, txt_ori, img_ori = (torch.randn(B, D) for _ in range(3)) - loss_fn = MaskedCLIPLoss() + loss_fn = MaskedInfoNCELoss() loss_fn.eval() - base = loss_fn(feats(txt, txt_ori, img_ori), mask)["clip_loss"] + base = loss_fn(feats(txt, txt_ori, img_ori), mask)["loss"] # Perturb ONLY the recon rows of every column operand that feeds negatives. txt2, txt_ori2, img_ori2 = txt.clone(), txt_ori.clone(), img_ori.clone() for t in (txt2, txt_ori2, img_ori2): t[2:] = torch.randn(2, D) - after = loss_fn(feats(txt2, txt_ori2, img_ori2), mask)["clip_loss"] + after = loss_fn(feats(txt2, txt_ori2, img_ori2), mask)["loss"] torch.testing.assert_close(base, after) def test_mask_holds_under_large_scale(self) -> None: @@ -128,27 +128,27 @@ def test_mask_holds_under_large_scale(self) -> None: # hardcoded -1e4, so masking holds even when logit_scale is large and # the *_ori operands are un-normalized (real logits can dwarf 1e4). # Loss/grad must stay finite and acc valid; eval exercises the argmax. - loss_fn = MaskedCLIPLoss() + loss_fn = MaskedInfoNCELoss() loss_fn.eval() feats = self._features(6, 8) big = torch.tensor(3000.0) feats["logit_scale"] = big feats["logit_scale_self"] = big feats["logit_scale_cl"] = big - feats["image_embed_ori"] = feats["image_embed_ori"] * 50 - feats["text_embed_ori"] = feats["text_embed_ori"] * 50 + feats["embed_a_ori"] = feats["embed_a_ori"] * 50 + feats["embed_b_ori"] = feats["embed_b_ori"] * 50 mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool) out = loss_fn(feats, mask) - self.assertTrue(torch.isfinite(out["clip_loss"])) + self.assertTrue(torch.isfinite(out["loss"])) loss_fn.train() - feats["image_embed"].grad = None - loss_fn(feats, mask)["clip_loss"].backward() - self.assertTrue(torch.isfinite(feats["image_embed"].grad).all()) + feats["embed_a"].grad = None + loss_fn(feats, mask)["loss"].backward() + self.assertTrue(torch.isfinite(feats["embed_a"].grad).all()) # --- Multi-process tests for the CLIP distributed all-gather path. --- # Validates ``_all_gather_with_grad`` (built on the differentiable -# ``torch.distributed.nn.functional.all_gather``) and ``MaskedCLIPLoss`` across +# ``torch.distributed.nn.functional.all_gather``) and ``MaskedInfoNCELoss`` across # ranks. Uses NCCL on GPU when >=2 devices are available (the production path the # reviewer cared about), else falls back to gloo/CPU, so it runs on a multi-GPU # box and in CPU CI alike. @@ -174,7 +174,7 @@ def _all_gather_worker(rank: int, world_size: int, port: int) -> None: device = _init(rank, world_size, port) # Each rank holds a distinct, rank-identifying tensor. x = torch.full((2, 3), float(rank + 1), device=device, requires_grad=True) - gathered = MaskedCLIPLoss._all_gather_with_grad([x])[0] + gathered = MaskedInfoNCELoss._all_gather_with_grad([x])[0] # Forward: gathered is (world_size*2, 3); rank r contributes rows # [2r : 2r+2] all equal to (r+1). @@ -202,24 +202,24 @@ def _masked_clip_worker(rank: int, world_size: int, port: int) -> None: B, D = 4, 8 scale = torch.tensor(np.log(1 / 0.07)).exp().to(device) feats = { - "image_embed": torch.randn(B, D, device=device, requires_grad=True), - "text_embed": torch.randn(B, D, device=device, requires_grad=True), - "image_embed_ori": torch.randn(B, D, device=device), - "text_embed_ori": torch.randn(B, D, device=device), + "embed_a": torch.randn(B, D, device=device, requires_grad=True), + "embed_b": torch.randn(B, D, device=device, requires_grad=True), + "embed_a_ori": torch.randn(B, D, device=device), + "embed_b_ori": torch.randn(B, D, device=device), "logit_scale_self": scale, "logit_scale_cl": scale, "logit_scale": scale, } mask = torch.ones(B, dtype=torch.bool, device=device) - loss_fn = MaskedCLIPLoss().to(device) + loss_fn = MaskedInfoNCELoss().to(device) out = loss_fn(feats, mask) - clip_loss = out["clip_loss"] + clip_loss = out["loss"] assert torch.isfinite(clip_loss).all(), f"rank{rank}: non-finite clip_loss" assert clip_loss.item() > 0.0, f"rank{rank}: clip_loss not positive" clip_loss.backward() - g = feats["image_embed"].grad + g = feats["embed_a"].grad assert g is not None and torch.isfinite(g).all(), f"rank{rank}: bad grad" dist.destroy_process_group() @@ -238,7 +238,7 @@ def _run(target) -> None: raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") -class ClipLossDistTest(unittest.TestCase): +class InfoNCEDistTest(unittest.TestCase): """2-rank tests for the CLIP distributed collectives.""" def test_all_gather_with_grad(self) -> None: diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 0bb1a1799..50292c3aa 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -21,8 +21,8 @@ from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature -from tzrec.loss.clip_loss import MaskedCLIPLoss from tzrec.loss.commitment_loss import CommitmentLoss +from tzrec.loss.infonce_loss import MaskedInfoNCELoss from tzrec.metrics.relative_l1 import RelativeL1 from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel @@ -138,11 +138,11 @@ def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: commitment_type=cfg.commitment_type, ) elif loss_type == "sid_clip_loss": - # The three learnable CLIP temperatures + the masked-CLIP module. + # The three learnable contrastive temperatures + the InfoNCE module. self._logit_scale_self = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) self._logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) self._logit_scale = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) - self._loss_modules["sid_clip_loss"] = MaskedCLIPLoss() + self._loss_modules["sid_clip_loss"] = MaskedInfoNCELoss() else: raise ValueError( f"LossConfig for a SID model must set a sid_loss variant, " @@ -193,16 +193,16 @@ def scaled(p: torch.Tensor) -> torch.Tensor: return p.clamp(max=_LOGIT_SCALE_MAX).exp() feats = { - "image_embed": predictions["clip_image"], - "text_embed": predictions["clip_text"], - "image_embed_ori": predictions["clip_image_ori"], - "text_embed_ori": predictions["clip_text_ori"], + "embed_a": predictions["embed_a"], + "embed_b": predictions["embed_b"], + "embed_a_ori": predictions["embed_a_ori"], + "embed_b_ori": predictions["embed_b_ori"], "logit_scale_self": scaled(self._logit_scale_self), "logit_scale_cl": scaled(self._logit_scale_cl), "logit_scale": scaled(self._logit_scale), } - out = self._loss_modules["sid_clip_loss"](feats, predictions["clip_mask"]) - return {"sid_clip_loss": out["clip_loss"]} + out = self._loss_modules["sid_clip_loss"](feats, predictions["pair_mask"]) + return {"sid_clip_loss": out["loss"]} else: raise ValueError(f"unsupported sid_loss variant: {loss_type!r}") diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index c114cbf25..0e2290a77 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -76,21 +76,27 @@ def __init__( cfg = self._model_config # SidRqvae proto message - # CLIP is enabled by a `sid_clip_loss` entry in ModelConfig.losses, which - # also carries the paired-feature names (data wiring). - clip_cfg = next( - ( - lc.sid_clip_loss - for lc in self._base_model_config.losses - if lc.WhichOneof("sid_loss") == "sid_clip_loss" - ), - None, + # The CLIP-style dual-encoder structure (which paired feature to encode, + # the dual path) is declared on the MODEL proto (`clip_config`); the + # contrastive OBJECTIVE is enabled by a `sid_clip_loss` entry in + # ModelConfig.losses. The two must be set together. + self._use_clip = cfg.HasField("clip_config") + self._clip_feature_name = ( + cfg.clip_config.clip_feature_name if self._use_clip else None ) - self._use_clip = clip_cfg is not None - self._clip_feature_name = clip_cfg.clip_feature_name if clip_cfg else None self._is_clip_pair_feature_name = ( - clip_cfg.is_clip_pair_feature_name if clip_cfg else None + cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None ) + has_clip_obj = any( + lc.WhichOneof("sid_loss") == "sid_clip_loss" + for lc in self._base_model_config.losses + ) + if self._use_clip != has_clip_obj: + raise ValueError( + "clip_config (model structure) and a sid_clip_loss entry in " + "losses (the objective) must be set together; got " + f"clip_config={self._use_clip}, sid_clip_loss={has_clip_obj}" + ) embed_dim = cfg.embed_dim # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero @@ -208,9 +214,10 @@ def _predict_mixed( "recon_mask": ~clip_mask, "encoder_out": torch.cat([z_e1, z_e2], dim=0), "latents": torch.cat([quant1.latents, quant2.latents], dim=0), - "clip_image": x_hat1, - "clip_text": x_hat2, - "clip_image_ori": embedding, - "clip_text_ori": fea2, - "clip_mask": clip_mask, + # generic contrastive operands (view a = main, view b = paired): + "embed_a": x_hat1, + "embed_b": x_hat2, + "embed_a_ori": embedding, + "embed_b_ori": fea2, + "pair_mask": clip_mask, } diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index ae53241ef..5b57c8b54 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -60,9 +60,10 @@ def _commitment_cfg( def _clip_cfg() -> loss_pb2.LossConfig: + # The contrastive objective marker (empty); the paired-feature wiring lives + # on the model proto (SidRqvae.clip_config), set in _create_model. lc = loss_pb2.LossConfig() - lc.sid_clip_loss.clip_feature_name = "image_emb" - lc.sid_clip_loss.is_clip_pair_feature_name = "is_clip_pair" + lc.sid_clip_loss.SetInParent() return lc @@ -89,6 +90,9 @@ def _create_model( ) losses = [_recon_loss_cfg(recon), _commitment_cfg()] if use_clip: + # structure on the model proto; objective marker in losses. + sid_rqvae_cfg.clip_config.clip_feature_name = "image_emb" + sid_rqvae_cfg.clip_config.is_clip_pair_feature_name = "is_clip_pair" losses.append(_clip_cfg()) # SID models read the item-embedding dense feature directly from the @@ -186,7 +190,7 @@ def test_rqvae_clip_mode(self) -> None: predictions = model.predict(batch) self.assertIn("codes", predictions) self.assertIn("x_hat", predictions) - self.assertIn("clip_image", predictions) + self.assertIn("embed_a", predictions) self.assertEqual(predictions["codes"].shape[0], B) losses = model.loss(predictions, batch) diff --git a/tzrec/protos/loss.proto b/tzrec/protos/loss.proto index 466479cb2..b60fcc668 100644 --- a/tzrec/protos/loss.proto +++ b/tzrec/protos/loss.proto @@ -41,12 +41,10 @@ message CommitmentLoss { optional string commitment_type = 2 [default = "l2"]; } -// CLIP contrastive loss for the dual (image/text) reconstruction path. +// Enables the contrastive (masked InfoNCE) objective for a CLIP-style SID model. +// The paired-feature wiring lives on the model (SidRqvae.clip_config); this just +// turns the objective on (any loss hyperparameters would go here). message SidClipLoss { - // Name of the second (paired) embedding feature inside the input Batch. - required string clip_feature_name = 1; - // Name of the per-row float feature flagging CLIP-pair rows (>0.5 = pair). - required string is_clip_pair_feature_name = 2; } message BinaryCrossEntropy { diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 4eb3ffffd..f3e02436c 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -26,6 +26,16 @@ message SinkhornConfig { optional bool enabled = 3 [default = true]; } +// CLIP-style dual-encoder wiring for SidRqvae: which paired feature to encode +// and which column flags the contrastive-pair rows. This is model structure / +// input contract (declared on the model), not loss config. +message ClipConfig { + // Name of the second (paired) embedding feature inside the input Batch. + required string clip_feature_name = 1; + // Name of the per-row float feature flagging pair rows (>0.5 = pair). + required string is_clip_pair_feature_name = 2; +} + message SidRqvae { // === Network structure === // Input embedding dimension. @@ -67,10 +77,15 @@ message SidRqvae { // omitted: enabled with iters=5, epsilon=10.0. Include the sub-config // to override params; set ``enabled: false`` inside it to disable. optional SinkhornConfig sinkhorn_config = 15; + // CLIP-style dual-encoder structure: when set, the model encodes a second + // (paired) feature and runs the contrastive path. This declares the model's + // input contract + topology; the contrastive OBJECTIVE is enabled separately + // by a `sid_clip_loss` entry in ModelConfig.losses (both must be set together). + optional ClipConfig clip_config = 16; - // Reconstruction, commitment and (optional) CLIP losses are configured via - // ModelConfig.losses (the LossConfig ``sid_loss`` oneof), not on this - // message — see tzrec/protos/loss.proto. + // Reconstruction, commitment and (optional) contrastive losses are configured + // via ModelConfig.losses (the LossConfig ``sid_loss`` oneof); only the CLIP + // feature wiring above lives on this message. // Name of the item embedding feature inside the input Batch. optional string embedding_feature_name = 40 [default = "item_emb"]; From 2002daf3442d3b4251570eddfd198862047f0be3 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 06:29:18 +0000 Subject: [PATCH 106/129] [refactor] SID: extract _masked_mean; inline recon distance into _sid_loss_impl Drop the _recon_loss method: its per-row masked reduction becomes a module-level `_masked_mean(per_sample, mask)` helper, and the recon distance (mse/l1/cosine branch) inlines directly into `_sid_loss_impl`'s recon dispatch. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 68 ++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 40 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 50292c3aa..0f1afad9d 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -36,10 +36,29 @@ # CLIP temperature init (reference CLIP: log(1 / 0.07)). _LOGIT_SCALE_INIT = float(np.log(1 / 0.07)) -# sid_loss reconstruction variants (``_recon_loss`` branches on these directly). +# sid_loss reconstruction variants (``_sid_loss_impl`` branches on these). _RECON_LOSSES = frozenset(("recon_l2_loss", "recon_l1_loss", "recon_cosine_loss")) +def _masked_mean( + per_sample: torch.Tensor, mask: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Mean of a per-row loss over the masked-in rows (all rows if ``mask`` None). + + The mixed recon+CLIP path applies the reconstruction loss to recon rows only; + the masked mean divides by the valid-row count (``div_no_nan`` keeps an empty + mask at 0). No data-dependent branching → ``torch.compile``-friendly. + + Args: + per_sample (Tensor): per-row loss, shape (B,). + mask (Tensor, optional): per-row bool; rows to include. + """ + if mask is None: + return per_sample.mean() + mask = mask.float() + return div_no_nan((per_sample * mask).sum(), mask.sum()) + + class BaseSidModel(BaseModel): """Shared base for semantic-ID (SID) generation models. @@ -174,13 +193,14 @@ def _sid_loss_impl( """Compute one ``sid_loss`` term from ``predictions``.""" loss_type = loss_cfg.WhichOneof("sid_loss") if loss_type in _RECON_LOSSES: - loss = self._recon_loss( - predictions["x_hat"], - predictions["recon_target"], - loss_type, - predictions.get("recon_mask"), - ) - return {loss_type: loss} + x_hat, x = predictions["x_hat"], predictions["recon_target"] + if loss_type == "recon_l2_loss": + per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) + elif loss_type == "recon_l1_loss": + per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) + else: # "recon_cosine_loss" + per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) + return {loss_type: _masked_mean(per_sample, predictions.get("recon_mask"))} elif loss_type == "commitment_loss": loss = self._loss_modules["commitment_loss"]( predictions["encoder_out"], predictions["latents"] @@ -206,38 +226,6 @@ def scaled(p: torch.Tensor) -> torch.Tensor: else: raise ValueError(f"unsupported sid_loss variant: {loss_type!r}") - def _recon_loss( - self, - x_hat: torch.Tensor, - x: torch.Tensor, - recon_loss: str, - mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Reconstruction loss for a ``sid_loss`` recon variant. - - Returns the mean over all rows, or — when ``mask`` (a per-row bool) is - given — the mean over only the masked-in rows (the mixed recon+CLIP path - applies recon loss to recon rows only). No data-dependent branching, so - it stays ``torch.compile``-friendly. - - Args: - x_hat (Tensor): reconstructed output, shape (B, D). - x (Tensor): original input, shape (B, D). - recon_loss (str): the recon variant, one of ``_RECON_LOSSES`` - (``recon_l2_loss`` | ``recon_l1_loss`` | ``recon_cosine_loss``). - mask (Tensor, optional): per-row bool; rows to include. - """ - if recon_loss == "recon_l2_loss": - per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) - elif recon_loss == "recon_l1_loss": - per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - else: # "recon_cosine_loss" - per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) - if mask is None: - return per_sample.mean() - mask = mask.float() - return div_no_nan((per_sample * mask).sum(), mask.sum()) - def init_metric(self) -> None: """Initialize the eval metrics shared by all SID models. From 0c2f6e1f0acc5e55636334b4b0ed64c834b36894 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 07:12:45 +0000 Subject: [PATCH 107/129] [refactor] SID: bind recon loss at init + merge ReconL2/L1/Cosine -> ReconLoss - Bind the recon loss like commitment/clip (init binds, loss calls): a recon_loss(recon_type) factory returns the per-row distance fn, bound into self._recon_fn in _init_sid_loss_impl and called in _sid_loss_impl. Removes the per-batch if/elif distance branching and the _RECON_LOSSES set. - Proto: collapse the three empty ReconL2Loss/ReconL1Loss/ReconCosineLoss marker messages into a single ReconLoss { recon_type = "l2"|"l1"|"cos" }, mirroring CommitmentLoss { commitment_type }. The sid_loss oneof now has one recon_loss variant; the loss is keyed "recon_loss". - Update sid_rqvae_mock.config + sid_rqvae_test to the merged form. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 53 +++++++++++++++-------- tzrec/models/sid_rqvae_test.py | 38 +++++++--------- tzrec/protos/loss.proto | 20 +++------ tzrec/tests/configs/sid_rqvae_mock.config | 3 +- 4 files changed, 61 insertions(+), 53 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 0f1afad9d..f77e3bc49 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -11,7 +11,7 @@ """BaseSidModel: shared base for semantic-ID generation models.""" -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import numpy as np import torch @@ -36,8 +36,25 @@ # CLIP temperature init (reference CLIP: log(1 / 0.07)). _LOGIT_SCALE_INIT = float(np.log(1 / 0.07)) -# sid_loss reconstruction variants (``_sid_loss_impl`` branches on these). -_RECON_LOSSES = frozenset(("recon_l2_loss", "recon_l1_loss", "recon_cosine_loss")) + +def recon_loss( + recon_type: str, +) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: + """Per-row reconstruction-distance fn for the configured ``recon_type``. + + Args: + recon_type (str): the distance, ``"l2"`` (mse), ``"l1"`` or ``"cos"``. + + Returns: + Callable: ``f(x_hat, x) -> (B,)`` per-row reconstruction distance. + """ + if recon_type == "l2": + return lambda x_hat, x: F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) + if recon_type == "l1": + return lambda x_hat, x: F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) + if recon_type == "cos": + return lambda x_hat, x: 1 - F.cosine_similarity(x_hat, x, dim=-1) + raise ValueError(f"recon_type must be 'l2', 'l1' or 'cos', got {recon_type!r}") def _masked_mean( @@ -138,17 +155,21 @@ def init_loss(self) -> None: Each ``LossConfig`` sets one ``sid_loss`` oneof variant (a reconstruction loss, the commitment loss, or the CLIP loss). Mirrors ``RankModel``: the - config drives which loss modules are registered, and :meth:`loss` - computes them from ``predictions``. + config drives what is bound here, and :meth:`loss` computes them from + ``predictions``. The reconstruction loss binds a per-row distance fn into + ``_recon_fn``; commitment/CLIP register modules into ``_loss_modules``. """ + self._recon_fn: Optional[ + Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + ] = None for loss_cfg in self._base_model_config.losses: self._init_sid_loss_impl(loss_cfg) def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: - """Register the module (if any) for one ``sid_loss`` config.""" + """Bind the loss (a recon fn or a module) for one ``sid_loss`` config.""" loss_type = loss_cfg.WhichOneof("sid_loss") - if loss_type in _RECON_LOSSES: - return # reconstruction losses are functional (no module) + if loss_type == "recon_loss": + self._recon_fn = recon_loss(loss_cfg.recon_loss.recon_type) elif loss_type == "commitment_loss": cfg = loss_cfg.commitment_loss latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) @@ -192,15 +213,13 @@ def _sid_loss_impl( ) -> Dict[str, torch.Tensor]: """Compute one ``sid_loss`` term from ``predictions``.""" loss_type = loss_cfg.WhichOneof("sid_loss") - if loss_type in _RECON_LOSSES: - x_hat, x = predictions["x_hat"], predictions["recon_target"] - if loss_type == "recon_l2_loss": - per_sample = F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) - elif loss_type == "recon_l1_loss": - per_sample = F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - else: # "recon_cosine_loss" - per_sample = 1 - F.cosine_similarity(x_hat, x, dim=-1) - return {loss_type: _masked_mean(per_sample, predictions.get("recon_mask"))} + if loss_type == "recon_loss": + per_sample = self._recon_fn( + predictions["x_hat"], predictions["recon_target"] + ) + return { + "recon_loss": _masked_mean(per_sample, predictions.get("recon_mask")) + } elif loss_type == "commitment_loss": loss = self._loss_modules["commitment_loss"]( predictions["encoder_out"], predictions["latents"] diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 5b57c8b54..c9f15f7a6 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -43,10 +43,10 @@ def _make_batch( ) -def _recon_loss_cfg(kind: str = "recon_l2_loss") -> loss_pb2.LossConfig: - """A LossConfig whose sid_loss oneof is the given recon variant.""" +def _recon_loss_cfg(recon_type: str = "l2") -> loss_pb2.LossConfig: + """A LossConfig with a recon_loss term of the given recon_type.""" lc = loss_pb2.LossConfig() - getattr(lc, kind).SetInParent() + lc.recon_loss.recon_type = recon_type return lc @@ -76,7 +76,7 @@ def _create_model( input_dim=32, embed_dim=8, n_layers=2, - recon="recon_l2_loss", + recon="l2", ): """Helper to create a SidRqvae model with config-driven losses.""" n_embed_list = [16] * n_layers @@ -139,7 +139,7 @@ def test_rqvae_train_mode(self) -> None: # loss() computes the configured recon + commitment terms. losses = model.loss(predictions, batch) - self.assertIn("recon_l2_loss", losses) + self.assertIn("recon_loss", losses) self.assertIn("commitment_loss", losses) total_loss = sum(losses.values()) @@ -194,7 +194,7 @@ def test_rqvae_clip_mode(self) -> None: self.assertEqual(predictions["codes"].shape[0], B) losses = model.loss(predictions, batch) - self.assertIn("recon_l2_loss", losses) + self.assertIn("recon_loss", losses) self.assertIn("commitment_loss", losses) self.assertIn("sid_clip_loss", losses) @@ -216,7 +216,7 @@ def test_rqvae_clip_all_recon(self) -> None: batch = self._clip_batch(B, input_dim, torch.zeros(B, 1)) losses = model.loss(model.predict(batch), batch) self.assertEqual(losses["sid_clip_loss"].item(), 0.0) - self.assertGreater(losses["recon_l2_loss"].item(), 0.0) + self.assertGreater(losses["recon_loss"].item(), 0.0) def test_rqvae_clip_all_clip(self) -> None: """Mixed mode with all-clip batch: recon term 0, clip term > 0.""" @@ -227,7 +227,7 @@ def test_rqvae_clip_all_clip(self) -> None: batch = self._clip_batch(B, input_dim, torch.ones(B, 1)) losses = model.loss(model.predict(batch), batch) - self.assertEqual(losses["recon_l2_loss"].item(), 0.0) + self.assertEqual(losses["recon_loss"].item(), 0.0) self.assertGreater(losses["sid_clip_loss"].item(), 0.0) def test_rqvae_backward(self) -> None: @@ -283,7 +283,7 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: labels={}, ) losses = model.loss(model.predict(batch), batch) - self.assertEqual(losses["recon_l2_loss"].item(), 0.0) + self.assertEqual(losses["recon_loss"].item(), 0.0) self.assertGreater(losses["sid_clip_loss"].item(), 0.0) @parameterized.expand( @@ -312,25 +312,19 @@ def test_sinkhorn_config(self, _name, enabled, expect_use_sinkhorn) -> None: for layer in model._quantizer.layers: self.assertEqual(layer.use_sinkhorn, expect_use_sinkhorn) - @parameterized.expand( - [ - ("recon_l2_loss",), - ("recon_l1_loss",), - ("recon_cosine_loss",), - ] - ) - def test_recon_loss_variant_branch(self, recon) -> None: - """Each recon variant runs end-to-end (grad flows through the decoder).""" + @parameterized.expand([("l2",), ("l1",), ("cos",)]) + def test_recon_type_branch(self, recon_type) -> None: + """Each recon_type runs end-to-end (grad flows through the decoder).""" B, input_dim = 4, 32 - model = self._create_model(input_dim=input_dim, recon=recon) + model = self._create_model(input_dim=input_dim, recon=recon_type) model.train() model.init_loss() losses = model.loss( model.predict(_make_batch(B, input_dim)), _make_batch(B, input_dim) ) - recon_loss = losses[recon] - self.assertTrue(torch.isfinite(recon_loss), f"{recon} not finite") - recon_loss.backward() # grad must flow through the decoder + recon = losses["recon_loss"] + self.assertTrue(torch.isfinite(recon), f"{recon_type} not finite") + recon.backward() # grad must flow through the decoder def test_logit_scale_clamped_prevents_overflow(self) -> None: """A raw logit_scale far above ln(100) must not overflow. diff --git a/tzrec/protos/loss.proto b/tzrec/protos/loss.proto index b60fcc668..5539b3c67 100644 --- a/tzrec/protos/loss.proto +++ b/tzrec/protos/loss.proto @@ -13,22 +13,16 @@ message LossConfig { // lists one LossConfig per term it trains on (a reconstruction loss, the // commitment loss, and optionally the CLIP contrastive loss). oneof sid_loss { - ReconL2Loss recon_l2_loss = 6; - ReconL1Loss recon_l1_loss = 7; - ReconCosineLoss recon_cosine_loss = 8; - CommitmentLoss commitment_loss = 9; - SidClipLoss sid_clip_loss = 10; + ReconLoss recon_loss = 6; + CommitmentLoss commitment_loss = 7; + SidClipLoss sid_clip_loss = 8; } } -// RQ-VAE reconstruction losses (input vs. decoder output). -message ReconL2Loss { -} - -message ReconL1Loss { -} - -message ReconCosineLoss { +// RQ-VAE reconstruction loss (input vs. decoder output). +message ReconLoss { + // Distance for the reconstruction term: "l2" (mse), "l1" or "cos". + optional string recon_type = 1 [default = "l2"]; } // RQ-VAE commitment loss between the encoder output and the per-layer diff --git a/tzrec/tests/configs/sid_rqvae_mock.config b/tzrec/tests/configs/sid_rqvae_mock.config index f94179621..e4efd7b23 100644 --- a/tzrec/tests/configs/sid_rqvae_mock.config +++ b/tzrec/tests/configs/sid_rqvae_mock.config @@ -52,7 +52,8 @@ model_config { embedding_feature_name: "item_emb" } losses { - recon_l2_loss { + recon_loss { + recon_type: "l2" } } losses { From 105abe33439642ad0b2e43c564f1631765b56bc6 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 07:49:01 +0000 Subject: [PATCH 108/129] [feat] SID: consume framework EmbeddingGroup/build_input (drop _extract_feature) Address review (round 5): use the standard build_input / EmbeddingGroup path instead of reading a single dense feature out of Batch.dense_features, so a SID model can take multiple content embeddings + side-info in one feature group (FORGE/PLUM motivation). - BaseSidModel: add init_input/build_input + self.embedding_group (called from __init__, as in every RankModel-based model); derive _input_dim from group_total_dim(feature_group); remove _extract_feature. - proto: drop input_dim (derived); embedding_feature_name -> feature_group (default "deep"); ClipConfig {clip,is_clip_pair}_feature_name -> {clip,clip_pair}_feature_group. - sid_rqvae/sid_rqkmeans: predict via build_input; CLIP dual path reads the paired + pair-flag groups. - mock configs: drop input_dim/embedding_feature_name (feature_groups now load-bearing). Unit tests reworked to real create_features + feature_groups. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 60 +++++++------ tzrec/models/sid_rqkmeans.py | 2 +- tzrec/models/sid_rqkmeans_test.py | 47 ++++++---- tzrec/models/sid_rqvae.py | 31 ++++--- tzrec/models/sid_rqvae_test.py | 90 +++++++++++++------- tzrec/protos/models/sid_model.proto | 45 +++++----- tzrec/tests/configs/sid_rqkmeans_mock.config | 2 - tzrec/tests/configs/sid_rqvae_mock.config | 2 - 8 files changed, 169 insertions(+), 110 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index f77e3bc49..4c8497ed2 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -19,13 +19,14 @@ import torchmetrics from torch import nn -from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.loss.commitment_loss import CommitmentLoss from tzrec.loss.infonce_loss import MaskedInfoNCELoss from tzrec.metrics.relative_l1 import RelativeL1 from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel +from tzrec.modules.embedding import EmbeddingGroup from tzrec.modules.utils import div_no_nan from tzrec.protos.loss_pb2 import LossConfig from tzrec.protos.model_pb2 import ModelConfig @@ -82,11 +83,14 @@ class BaseSidModel(BaseModel): Factors the structure common to :class:`SidRqvae` (RQ-VAE) and :class:`SidRqkmeans` (residual K-Means): - - the shared config fields every SID proto carries — - ``embedding_feature_name`` (``_embedding_feature_name``), ``input_dim`` - (``_input_dim``), ``normalize_residuals`` (``_normalize_residuals``), + - the shared config fields every SID proto carries — ``feature_group`` + (``_feature_group``), ``normalize_residuals`` (``_normalize_residuals``), and the per-layer ``codebook`` (``_n_embed_list`` / ``_n_layers``), - - reading the item-embedding feature out of ``Batch.dense_features``, + - building the main input through the framework's :class:`EmbeddingGroup` + (:meth:`init_input` / :meth:`build_input`), so a SID model consumes the + same grouped/concatenated feature tensor as every other model and + ``_input_dim`` is *derived* from the group's total dimension (supporting + multiple content embeddings + side-info in one group), - the eval metrics every SID model reports — reconstruction ``mse`` and ``unique_sid_ratio`` (mean per-batch unique-SID ratio, a diversity proxy). @@ -116,39 +120,41 @@ def __init__( cfg = self._model_config # Config fields shared by every SID model (present on each SID proto - # message): the item-embedding feature, the input dimension, the - # residual-normalization toggle, and the per-layer codebook. - self._embedding_feature_name = cfg.embedding_feature_name - self._input_dim = cfg.input_dim + # message): the main input feature group, the residual-normalization + # toggle, and the per-layer codebook. + self._feature_group = cfg.feature_group self._normalize_residuals = cfg.normalize_residuals if not cfg.codebook: raise ValueError("codebook must be set, e.g. [256, 256, 256]") self._n_embed_list = list(cfg.codebook) - # Fail fast: a zero codebook entry / input_dim==0 only errors opaquely - # deep inside faiss, after the whole training pass. + # Fail fast: a zero codebook entry only errors opaquely deep inside + # faiss, after the whole training pass. if any(k < 1 for k in self._n_embed_list): raise ValueError( f"every codebook entry must be >= 1, got {self._n_embed_list}" ) - if self._input_dim < 1: - raise ValueError(f"input_dim must be >= 1, got {self._input_dim}") self._n_layers = len(self._n_embed_list) - def _extract_feature( - self, batch: Batch, feature_name: Optional[str] = None - ) -> torch.Tensor: - """Extract a named dense feature from ``Batch.dense_features``. + # Build the framework's EmbeddingGroup (same path every model uses) and + # derive the encoder input dim from the main group's total dimension — + # the group may hold one content embedding or several content + side-info + # features, all concatenated into ``_input_dim``. + self.init_input() + self._input_dim = self.embedding_group.group_total_dim(self._feature_group) + if self._input_dim < 1: + raise ValueError( + f"feature group {self._feature_group!r} has total dim " + f"{self._input_dim}; it must be >= 1" + ) + + def init_input(self) -> None: + """Build the :class:`EmbeddingGroup` from features + feature groups.""" + self.embedding_group = EmbeddingGroup(self._features, self._feature_groups) - Args: - batch (Batch): input batch data. - feature_name (str, optional): feature name to extract. - Defaults to ``self._embedding_feature_name``. - """ - if feature_name is None: - feature_name = self._embedding_feature_name - kt = batch.dense_features[BASE_DATA_GROUP] - return kt[feature_name] + def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Build grouped input features: ``{group_name: (B, group_total_dim)}``.""" + return self.embedding_group(batch) def init_loss(self) -> None: """Initialize SID loss modules from ``ModelConfig.losses``. @@ -285,7 +291,7 @@ def update_metric( if "x_hat" not in predictions: return recon = predictions["x_hat"] - embedding = self._extract_feature(batch) + embedding = self.build_input(batch)[self._feature_group] self._metric_modules["mse"].update(recon, embedding) self._metric_modules["rel_loss"].update(recon, embedding) self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 59b05af41..716651089 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -121,7 +121,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: Return: predictions (dict): a dict of predicted result. """ - embedding = self._extract_feature(batch) + embedding = self.build_input(batch)[self._feature_group] # Training: reservoir-sample only; codes are dummy until the fit. if self.is_train: diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 0b68fefa6..767662ce4 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -17,12 +17,37 @@ from torchrec import KeyedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.features.feature import create_features from tzrec.models.sid_rqkmeans import SidRqkmeans -from tzrec.protos import model_pb2 +from tzrec.protos import feature_pb2, model_pb2 from tzrec.protos.models import sid_model_pb2 from tzrec.utils.state_dict_util import init_parameters +def _features_and_groups(input_dim: int): + """Real ``item_emb`` raw feature + the ``deep`` group it feeds. + + SID models consume the framework's EmbeddingGroup (built from these), and + derive the K-Means dimension from the ``deep`` group's total dim — so real + features + feature_groups are required, as in every other model test. + """ + feature_cfgs = [ + feature_pb2.FeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="item_emb", value_dim=input_dim + ) + ) + ] + groups = [ + model_pb2.FeatureGroupConfig( + group_name="deep", + feature_names=["item_emb"], + group_type=model_pb2.FeatureGroupType.DEEP, + ) + ] + return create_features(feature_cfgs), groups + + def _batch_from_rows(rows: torch.Tensor) -> Batch: """Wrap explicit ``item_emb`` rows in a minimal Batch.""" dense_feature = KeyedTensor.from_tensor_list(keys=["item_emb"], tensors=[rows]) @@ -58,26 +83,23 @@ def _create_model( normalize_residuals=False, train_sample_size=0, ): - """Build a SidRqkmeans on CPU with params initialized. - - SID models read the item-embedding dense feature directly from the - batch and do not consume feature_groups, so none is set. - """ + """Build a SidRqkmeans on CPU with params initialized.""" n_embed_list = codebook if codebook is not None else [16] * n_layers faiss_kwargs = sid_model_pb2.FaissKmeansConfig( niter=niter, verbose=False, seed=1234 ) cfg = sid_model_pb2.SidRqkmeans( - input_dim=input_dim, codebook=n_embed_list, normalize_residuals=normalize_residuals, faiss_kmeans_kwargs=faiss_kwargs, - embedding_feature_name="item_emb", train_sample_size=train_sample_size, ) + features, feature_groups = _features_and_groups(input_dim) model = SidRqkmeans( - model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), - features=[], + model_config=model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqkmeans=cfg + ), + features=features, labels=[], ) init_parameters(model, device=torch.device("cpu")) @@ -119,11 +141,6 @@ def test_init_raises_on_zero_codebook_entry(self) -> None: with self.assertRaisesRegex(ValueError, "codebook entry must be >= 1"): self._create_model(codebook=[16, 0]) - def test_init_raises_on_zero_input_dim(self) -> None: - """input_dim < 1 fails fast at construction.""" - with self.assertRaisesRegex(ValueError, "input_dim must be >= 1"): - self._create_model(input_dim=0) - def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 0e2290a77..c9ecec5cb 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -76,16 +76,16 @@ def __init__( cfg = self._model_config # SidRqvae proto message - # The CLIP-style dual-encoder structure (which paired feature to encode, - # the dual path) is declared on the MODEL proto (`clip_config`); the - # contrastive OBJECTIVE is enabled by a `sid_clip_loss` entry in + # The CLIP-style dual-encoder structure (which paired feature group to + # encode, the dual path) is declared on the MODEL proto (`clip_config`); + # the contrastive OBJECTIVE is enabled by a `sid_clip_loss` entry in # ModelConfig.losses. The two must be set together. self._use_clip = cfg.HasField("clip_config") - self._clip_feature_name = ( - cfg.clip_config.clip_feature_name if self._use_clip else None + self._clip_feature_group = ( + cfg.clip_config.clip_feature_group if self._use_clip else None ) - self._is_clip_pair_feature_name = ( - cfg.clip_config.is_clip_pair_feature_name if self._use_clip else None + self._clip_pair_feature_group = ( + cfg.clip_config.clip_pair_feature_group if self._use_clip else None ) has_clip_obj = any( lc.WhichOneof("sid_loss") == "sid_clip_loss" @@ -165,13 +165,14 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: Return: predictions (dict): a dict of predicted result. """ - embedding = self._extract_feature(batch) + grouped = self.build_input(batch) + embedding = grouped[self._feature_group] if self._is_inference: # Codes-only path: get_codes does just the residual walk (no decode, # no commitment latents), so neither dual-path branch is needed. return {"codes": self._quantizer.get_codes(self._encode(embedding))} if self._use_clip: - return self._predict_mixed(embedding, batch) + return self._predict_mixed(grouped) return self._predict_rqvae(embedding) def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: @@ -187,16 +188,20 @@ def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: } def _predict_mixed( - self, embedding: torch.Tensor, batch: Batch + self, grouped: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: - """Mixed recon + CLIP: dual path over the embedding + its paired feature. + """Mixed recon + CLIP: dual path over the main + paired feature groups. ``encoder_out`` / ``latents`` stack both paths so the commitment loss averages over them; ``recon_mask`` (= non-CLIP rows) restricts the recon loss to reconstruction-only rows. + + Args: + grouped (dict): the EmbeddingGroup output (group name -> tensor). """ - fea2 = self._extract_feature(batch, self._clip_feature_name) - is_clip_pair_raw = self._extract_feature(batch, self._is_clip_pair_feature_name) + embedding = grouped[self._feature_group] + fea2 = grouped[self._clip_feature_group] + is_clip_pair_raw = grouped[self._clip_pair_feature_group] clip_mask = is_clip_pair_raw.view(is_clip_pair_raw.shape[0], -1)[:, 0] > 0.5 z_e1 = self._encode(embedding) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index c9f15f7a6..5eebdaa28 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -16,26 +16,47 @@ from torchrec import KeyedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.features.feature import create_features from tzrec.models.sid_rqvae import SidRqvae -from tzrec.protos import loss_pb2, model_pb2 +from tzrec.protos import feature_pb2, loss_pb2, model_pb2 from tzrec.protos.models import sid_model_pb2 from tzrec.utils.state_dict_util import init_parameters -def _make_batch( - batch_size: int, - input_dim: int, - feature_name: str = "item_emb", - extra_features: dict = None, -) -> Batch: - """Create a minimal Batch with dense embedding features.""" - keys = [feature_name] - tensors = [torch.randn(batch_size, input_dim)] - if extra_features: - for k, v in extra_features.items(): - keys.append(k) - tensors.append(v) - dense_feature = KeyedTensor.from_tensor_list(keys=keys, tensors=tensors) +def _features_and_groups(input_dim: int, use_clip: bool = False): + """Real raw features + feature groups for a SID model. + + Mirrors how every other model test wires inputs: ``create_features`` builds + the ``BaseFeature`` objects and ``feature_groups`` (consumed by the model's + :class:`EmbeddingGroup`) name the main ``deep`` group — plus, for CLIP, the + paired image group and the per-row pair-flag group. + """ + + def _raw(name: str, dim: int) -> feature_pb2.FeatureConfig: + return feature_pb2.FeatureConfig( + raw_feature=feature_pb2.RawFeature(feature_name=name, value_dim=dim) + ) + + def _deep(group_name: str, feature_name: str) -> model_pb2.FeatureGroupConfig: + return model_pb2.FeatureGroupConfig( + group_name=group_name, + feature_names=[feature_name], + group_type=model_pb2.FeatureGroupType.DEEP, + ) + + feature_cfgs = [_raw("item_emb", input_dim)] + groups = [_deep("deep", "item_emb")] + if use_clip: + feature_cfgs += [_raw("image_emb", input_dim), _raw("is_clip_pair", 1)] + groups += [_deep("clip_image", "image_emb"), _deep("clip_pair", "is_clip_pair")] + return create_features(feature_cfgs), groups + + +def _make_batch(batch_size: int, input_dim: int) -> Batch: + """Create a minimal Batch with the ``item_emb`` dense feature.""" + dense_feature = KeyedTensor.from_tensor_list( + keys=["item_emb"], tensors=[torch.randn(batch_size, input_dim)] + ) return Batch( dense_features={BASE_DATA_GROUP: dense_feature}, sparse_features={}, @@ -81,25 +102,26 @@ def _create_model( """Helper to create a SidRqvae model with config-driven losses.""" n_embed_list = [16] * n_layers sid_rqvae_cfg = sid_model_pb2.SidRqvae( - input_dim=input_dim, embed_dim=embed_dim, codebook=n_embed_list, forward_mode="ste", kmeans_init=False, - embedding_feature_name="item_emb", ) losses = [_recon_loss_cfg(recon), _commitment_cfg()] if use_clip: - # structure on the model proto; objective marker in losses. - sid_rqvae_cfg.clip_config.clip_feature_name = "image_emb" - sid_rqvae_cfg.clip_config.is_clip_pair_feature_name = "is_clip_pair" + # structure on the model proto (paired + pair-flag groups); + # objective marker in losses. + sid_rqvae_cfg.clip_config.clip_feature_group = "clip_image" + sid_rqvae_cfg.clip_config.clip_pair_feature_group = "clip_pair" losses.append(_clip_cfg()) - # SID models read the item-embedding dense feature directly from the - # batch; they do not consume feature_groups, so none is set (which - # keeps the config consistent with the empty ``features`` list). - model_config = model_pb2.ModelConfig(sid_rqvae=sid_rqvae_cfg, losses=losses) - model = SidRqvae(model_config=model_config, features=[], labels=[]) + # SID models consume the framework's EmbeddingGroup: input_dim is derived + # from the ``deep`` group, so real features + feature_groups are required. + features, feature_groups = _features_and_groups(input_dim, use_clip) + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqvae=sid_rqvae_cfg, losses=losses + ) + model = SidRqvae(model_config=model_config, features=features, labels=[]) init_parameters(model, device=torch.device("cpu")) return model @@ -248,14 +270,17 @@ def test_rqvae_backward(self) -> None: def test_commitment_latent_weight_wrong_length_raises(self) -> None: """A commitment_loss with a bad latent_weight length fails in init_loss.""" + features, feature_groups = _features_and_groups(32) for bad in ([1.0], [1.0, 0.5, 0.25]): cfg = sid_model_pb2.SidRqvae( - input_dim=32, embed_dim=8, codebook=[16, 16], kmeans_init=False + embed_dim=8, codebook=[16, 16], kmeans_init=False ) model_config = model_pb2.ModelConfig( - sid_rqvae=cfg, losses=[_commitment_cfg(latent_weight=bad)] + feature_groups=feature_groups, + sid_rqvae=cfg, + losses=[_commitment_cfg(latent_weight=bad)], ) - model = SidRqvae(model_config=model_config, features=[], labels=[]) + model = SidRqvae(model_config=model_config, features=features, labels=[]) with self.assertRaisesRegex(ValueError, "latent_weight"): model.init_loss() @@ -296,17 +321,20 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: def test_sinkhorn_config(self, _name, enabled, expect_use_sinkhorn) -> None: """``sinkhorn_config.enabled`` (or its omission) drives layer.use_sinkhorn.""" cfg = sid_model_pb2.SidRqvae( - input_dim=32, embed_dim=8, codebook=[16, 16], forward_mode="ste", kmeans_init=False, - embedding_feature_name="item_emb", ) if enabled is not None: cfg.sinkhorn_config.CopyFrom(sid_model_pb2.SinkhornConfig(enabled=enabled)) + features, feature_groups = _features_and_groups(32) model = SidRqvae( - model_config=model_pb2.ModelConfig(sid_rqvae=cfg), features=[], labels=[] + model_config=model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqvae=cfg + ), + features=features, + labels=[], ) init_parameters(model, device=torch.device("cpu")) for layer in model._quantizer.layers: diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index f3e02436c..fd5b4d9b5 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -26,20 +26,22 @@ message SinkhornConfig { optional bool enabled = 3 [default = true]; } -// CLIP-style dual-encoder wiring for SidRqvae: which paired feature to encode -// and which column flags the contrastive-pair rows. This is model structure / -// input contract (declared on the model), not loss config. +// CLIP-style dual-encoder wiring for SidRqvae: which paired feature group to +// encode and which group flags the contrastive-pair rows. This is model +// structure / input contract (declared on the model), not loss config. message ClipConfig { - // Name of the second (paired) embedding feature inside the input Batch. - required string clip_feature_name = 1; - // Name of the per-row float feature flagging pair rows (>0.5 = pair). - required string is_clip_pair_feature_name = 2; + // Name of the second (paired) embedding FEATURE GROUP (built by the same + // EmbeddingGroup as the main input; same total dim as `feature_group`). + required string clip_feature_group = 1; + // Name of the per-row pair-flag FEATURE GROUP (a single raw feature, + // dim 1; >0.5 = contrastive pair). + required string clip_pair_feature_group = 2; } message SidRqvae { // === Network structure === - // Input embedding dimension. - optional uint32 input_dim = 1 [default = 512]; + // (input_dim is not configured here — it is derived from the total + // dimension of the `feature_group` built by the model's EmbeddingGroup.) // Quantization latent dimension (encoder output / codebook dim). optional uint32 embed_dim = 2 [default = 64]; // Encoder hidden layer sizes, e.g. [256, 128]. @@ -78,23 +80,26 @@ message SidRqvae { // to override params; set ``enabled: false`` inside it to disable. optional SinkhornConfig sinkhorn_config = 15; // CLIP-style dual-encoder structure: when set, the model encodes a second - // (paired) feature and runs the contrastive path. This declares the model's - // input contract + topology; the contrastive OBJECTIVE is enabled separately - // by a `sid_clip_loss` entry in ModelConfig.losses (both must be set together). + // (paired) feature group and runs the contrastive path. This declares the + // model's input contract + topology; the contrastive OBJECTIVE is enabled + // separately by a `sid_clip_loss` entry in ModelConfig.losses (both must be + // set together). optional ClipConfig clip_config = 16; // Reconstruction, commitment and (optional) contrastive losses are configured // via ModelConfig.losses (the LossConfig ``sid_loss`` oneof); only the CLIP // feature wiring above lives on this message. - // Name of the item embedding feature inside the input Batch. - optional string embedding_feature_name = 40 [default = "item_emb"]; + // Name of the main input FEATURE GROUP (built by the model's EmbeddingGroup + // from ModelConfig.feature_groups). May hold one or many content/side-info + // features; their concatenated dim is the encoder input_dim. + optional string feature_group = 40 [default = "deep"]; } message SidRqkmeans { - // Input embedding dimension (K-Means runs directly on raw embeddings, - // no encoder). - optional uint32 input_dim = 1 [default = 512]; + // (input_dim is not configured here — K-Means runs directly on the raw + // embeddings of the `feature_group` built by the model's EmbeddingGroup, + // whose total dim is the K-Means dimension.) // Per-layer cluster counts, e.g. [256, 256, 256]. // List length is the number of residual quantization layers. Entries // may differ per layer (non-uniform codebooks such as [256, 512, 1024] @@ -113,6 +118,8 @@ message SidRqkmeans { // are wasted. optional uint32 train_sample_size = 6 [default = 0]; - // Name of the item embedding feature inside the input Batch. - optional string embedding_feature_name = 40 [default = "item_emb"]; + // Name of the main input FEATURE GROUP (built by the model's EmbeddingGroup + // from ModelConfig.feature_groups). May hold one or many content/side-info + // features; their concatenated dim is the K-Means dimension. + optional string feature_group = 40 [default = "deep"]; } diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config index 0e6dec907..cd95dac8d 100644 --- a/tzrec/tests/configs/sid_rqkmeans_mock.config +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -42,12 +42,10 @@ model_config { group_type: DEEP } sid_rqkmeans { - input_dim: 16 codebook: 16 codebook: 16 codebook: 16 normalize_residuals: false - embedding_feature_name: "item_emb" faiss_kmeans_kwargs { niter: 5 seed: 42 diff --git a/tzrec/tests/configs/sid_rqvae_mock.config b/tzrec/tests/configs/sid_rqvae_mock.config index e4efd7b23..6f4691136 100644 --- a/tzrec/tests/configs/sid_rqvae_mock.config +++ b/tzrec/tests/configs/sid_rqvae_mock.config @@ -41,7 +41,6 @@ model_config { group_type: DEEP } sid_rqvae { - input_dim: 16 embed_dim: 8 hidden_dims: 16 codebook: 16 @@ -49,7 +48,6 @@ model_config { codebook: 16 forward_mode: "ste" kmeans_init: false - embedding_feature_name: "item_emb" } losses { recon_loss { From e11faab5d78ba59b62871915f6fd946db58d8891 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 08:21:16 +0000 Subject: [PATCH 109/129] [feat] SID CLIP: fail-fast dim guard + mock config; trim comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - SidRqvae: validate at init that clip_feature_group's total dim equals the main feature_group dim (both share one encoder) — else fail fast instead of an opaque matmul shape error on the first contrastive forward; add a test. - add tzrec/tests/configs/sid_rqvae_clip_mock.config exercising the full CLIP path (deep + clip_image + clip_pair groups, clip_config, sid_clip_loss). - trim verbose/narrative comments to code-focused one-liners. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 13 +-- tzrec/models/sid_rqvae.py | 17 +++- tzrec/models/sid_rqvae_test.py | 33 +++++-- .../tests/configs/sid_rqvae_clip_mock.config | 97 +++++++++++++++++++ 4 files changed, 140 insertions(+), 20 deletions(-) create mode 100644 tzrec/tests/configs/sid_rqvae_clip_mock.config diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 4c8497ed2..1f0a0b4a5 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -119,27 +119,22 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) cfg = self._model_config - # Config fields shared by every SID model (present on each SID proto - # message): the main input feature group, the residual-normalization - # toggle, and the per-layer codebook. + # Config fields shared by every SID proto message. self._feature_group = cfg.feature_group self._normalize_residuals = cfg.normalize_residuals if not cfg.codebook: raise ValueError("codebook must be set, e.g. [256, 256, 256]") self._n_embed_list = list(cfg.codebook) - # Fail fast: a zero codebook entry only errors opaquely deep inside - # faiss, after the whole training pass. + # Fail fast: a zero entry only errors opaquely deep in faiss later. if any(k < 1 for k in self._n_embed_list): raise ValueError( f"every codebook entry must be >= 1, got {self._n_embed_list}" ) self._n_layers = len(self._n_embed_list) - # Build the framework's EmbeddingGroup (same path every model uses) and - # derive the encoder input dim from the main group's total dimension — - # the group may hold one content embedding or several content + side-info - # features, all concatenated into ``_input_dim``. + # Derive the encoder input dim from the main group's total dim (it may + # concatenate several content + side-info features). self.init_input() self._input_dim = self.embedding_group.group_total_dim(self._feature_group) if self._input_dim < 1: diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index c9ecec5cb..46b84540a 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -76,10 +76,8 @@ def __init__( cfg = self._model_config # SidRqvae proto message - # The CLIP-style dual-encoder structure (which paired feature group to - # encode, the dual path) is declared on the MODEL proto (`clip_config`); - # the contrastive OBJECTIVE is enabled by a `sid_clip_loss` entry in - # ModelConfig.losses. The two must be set together. + # Structure (clip_config) lives on the model proto; the objective + # (sid_clip_loss) lives in losses. The two must be set together. self._use_clip = cfg.HasField("clip_config") self._clip_feature_group = ( cfg.clip_config.clip_feature_group if self._use_clip else None @@ -97,6 +95,17 @@ def __init__( "losses (the objective) must be set together; got " f"clip_config={self._use_clip}, sid_clip_loss={has_clip_obj}" ) + # The paired group shares the main encoder, so it must match input_dim; + # fail fast here instead of an opaque matmul error on the first forward. + if self._use_clip: + clip_dim = self.embedding_group.group_total_dim(self._clip_feature_group) + if clip_dim != self._input_dim: + raise ValueError( + f"clip_feature_group {self._clip_feature_group!r} has total " + f"dim {clip_dim}, but it is encoded by the same encoder as " + f"the main feature_group (dim {self._input_dim}); the two " + "must match" + ) embed_dim = cfg.embed_dim # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 5eebdaa28..f18328fe5 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -23,13 +23,15 @@ from tzrec.utils.state_dict_util import init_parameters -def _features_and_groups(input_dim: int, use_clip: bool = False): +def _features_and_groups(input_dim: int, use_clip: bool = False, clip_dim: int = None): """Real raw features + feature groups for a SID model. Mirrors how every other model test wires inputs: ``create_features`` builds the ``BaseFeature`` objects and ``feature_groups`` (consumed by the model's :class:`EmbeddingGroup`) name the main ``deep`` group — plus, for CLIP, the - paired image group and the per-row pair-flag group. + paired image group and the per-row pair-flag group. ``clip_dim`` (default: + match ``input_dim``) sizes the paired group, so a test can deliberately + mismatch it. """ def _raw(name: str, dim: int) -> feature_pb2.FeatureConfig: @@ -47,7 +49,10 @@ def _deep(group_name: str, feature_name: str) -> model_pb2.FeatureGroupConfig: feature_cfgs = [_raw("item_emb", input_dim)] groups = [_deep("deep", "item_emb")] if use_clip: - feature_cfgs += [_raw("image_emb", input_dim), _raw("is_clip_pair", 1)] + feature_cfgs += [ + _raw("image_emb", clip_dim if clip_dim is not None else input_dim), + _raw("is_clip_pair", 1), + ] groups += [_deep("clip_image", "image_emb"), _deep("clip_pair", "is_clip_pair")] return create_features(feature_cfgs), groups @@ -109,14 +114,11 @@ def _create_model( ) losses = [_recon_loss_cfg(recon), _commitment_cfg()] if use_clip: - # structure on the model proto (paired + pair-flag groups); - # objective marker in losses. sid_rqvae_cfg.clip_config.clip_feature_group = "clip_image" sid_rqvae_cfg.clip_config.clip_pair_feature_group = "clip_pair" losses.append(_clip_cfg()) - # SID models consume the framework's EmbeddingGroup: input_dim is derived - # from the ``deep`` group, so real features + feature_groups are required. + # Real features + feature_groups: input_dim is derived from the group. features, feature_groups = _features_and_groups(input_dim, use_clip) model_config = model_pb2.ModelConfig( feature_groups=feature_groups, sid_rqvae=sid_rqvae_cfg, losses=losses @@ -284,6 +286,23 @@ def test_commitment_latent_weight_wrong_length_raises(self) -> None: with self.assertRaisesRegex(ValueError, "latent_weight"): model.init_loss() + def test_clip_feature_group_dim_mismatch_raises(self) -> None: + """A CLIP paired group whose dim != the main group fails fast at init. + + The paired feature is encoded by the same encoder as the main input, so + a dim mismatch would otherwise crash with an opaque matmul shape error + on the first contrastive forward — not at construction. + """ + features, feature_groups = _features_and_groups(32, use_clip=True, clip_dim=16) + cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) + cfg.clip_config.clip_feature_group = "clip_image" + cfg.clip_config.clip_pair_feature_group = "clip_pair" + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqvae=cfg, losses=[_clip_cfg()] + ) + with self.assertRaisesRegex(ValueError, "must match"): + SidRqvae(model_config=model_config, features=features, labels=[]) + def test_clip_mask_uses_flag_not_equality(self) -> None: """The is_clip_pair flag, not bit-exact equality, drives routing. diff --git a/tzrec/tests/configs/sid_rqvae_clip_mock.config b/tzrec/tests/configs/sid_rqvae_clip_mock.config new file mode 100644 index 000000000..9a1cf9fa1 --- /dev/null +++ b/tzrec/tests/configs/sid_rqvae_clip_mock.config @@ -0,0 +1,97 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/sid_rqvae_clip_mock" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 2 + save_checkpoints_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 256 + dataset_type: ParquetDataset + fg_mode: FG_DAG + num_workers: 2 +} +feature_configs { + raw_feature { + feature_name: "item_emb" + expression: "item:embedding" + value_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "image_emb" + expression: "item:image_embedding" + value_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "is_clip_pair" + expression: "item:is_clip_pair" + value_dim: 1 + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "item_emb" + group_type: DEEP + } + feature_groups { + group_name: "clip_image" + feature_names: "image_emb" + group_type: DEEP + } + feature_groups { + group_name: "clip_pair" + feature_names: "is_clip_pair" + group_type: DEEP + } + sid_rqvae { + embed_dim: 8 + hidden_dims: 16 + codebook: 16 + codebook: 16 + codebook: 16 + forward_mode: "ste" + kmeans_init: false + # clip_image shares the encoder with "deep", so it must match its dim; + # clip_pair flags rows (>0.5 = pair). Objective: sid_clip_loss below. + clip_config { + clip_feature_group: "clip_image" + clip_pair_feature_group: "clip_pair" + } + } + losses { + recon_loss { + recon_type: "l2" + } + } + losses { + commitment_loss { + latent_weight: 1.0 + latent_weight: 0.5 + } + } + losses { + sid_clip_loss { + } + } +} From e8c72fd0a9dcfbb0eab40da223b51b3bd56603af Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 09:16:35 +0000 Subject: [PATCH 110/129] =?UTF-8?q?[refactor]=20SID:=20review=20cleanup=20?= =?UTF-8?q?=E2=80=94=20merge=20RQ-VAE=20pass,=20fix=20eval=20metric=20+=20?= =?UTF-8?q?CLIP=20guards?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Post-E2 review/cleanup of the SID models: - efficiency: update_metric reads predictions["recon_target"] instead of re- running build_input per eval step; SidRqkmeans exposes recon_target alongside x_hat in fitted-eval predictions. - merge: extract SidRqvae._rqvae_pass (encode->quantize->decode), used by _predict_rqvae once and _predict_mixed twice (was triplicated). - bug: CLIP eval mse/rel_loss now respect recon_mask, so they score the same (non-pair) rows the recon loss optimizes instead of all rows. - fail-fast: validate the main feature_group (has_group) and the CLIP groups exist, the paired group matches input_dim, and the pair-flag group is dim-1 — instead of opaque KeyError/matmul errors on the first forward. - docs: correct the update_train_metric + sid_integration_test docstrings that the EmbeddingGroup refactor made stale. - tests: cover the CLIP group guards + metric masking; restore the derived input_dim<1 guard test (via a 0-dim group). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 36 ++++++++++++++----- tzrec/models/sid_rqkmeans.py | 8 +++-- tzrec/models/sid_rqkmeans_test.py | 19 +++++++--- tzrec/models/sid_rqvae.py | 55 ++++++++++++++++++++-------- tzrec/models/sid_rqvae_test.py | 56 ++++++++++++++++++++++++++--- tzrec/tests/sid_integration_test.py | 5 +-- 6 files changed, 142 insertions(+), 37 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 1f0a0b4a5..06d0dc0cd 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -133,9 +133,16 @@ def __init__( ) self._n_layers = len(self._n_embed_list) - # Derive the encoder input dim from the main group's total dim (it may - # concatenate several content + side-info features). + # Built in the base __init__ (not the subclass like Rank/Match models) + # so _input_dim is ready before the subclass builds its encoder; derived + # from the main group's total dim (which may concatenate several + # content + side-info features). self.init_input() + if not self.embedding_group.has_group(self._feature_group): + raise ValueError( + f"feature_group {self._feature_group!r} is not in " + f"model_config.feature_groups {self.embedding_group.group_names()}" + ) self._input_dim = self.embedding_group.group_total_dim(self._feature_group) if self._input_dim < 1: raise ValueError( @@ -270,13 +277,17 @@ def update_metric( batch: Batch, losses: Optional[Dict[str, torch.Tensor]] = None, ) -> None: - """Update eval metrics from the reconstruction + the re-extracted input. + """Update eval metrics from the reconstruction vs. the input embedding. ``predictions["x_hat"]`` is the model's reconstruction of the input embedding (the centroid sum for RQ-KMeans, the decoder output for - RQ-VAE). Subclasses expose it only when it is meaningful, so a - not-yet-fitted model omits it and this logs nothing. The target - embedding is re-extracted from ``batch`` (it is an input, not an output). + RQ-VAE); ``predictions["recon_target"]`` is the input it reconstructs. + Subclasses expose both only when meaningful, so a not-yet-fitted model + omits them and this logs nothing. (Reading the target from + ``predictions`` avoids a second ``build_input`` pass over ``batch``.) + For the mixed CLIP path the reconstruction is scored only on the + non-pair rows (``recon_mask``), matching the masked training recon loss + so the eval mse/rel_loss stay comparable to the optimized objective. Args: predictions (dict): a dict of predicted result. @@ -286,7 +297,13 @@ def update_metric( if "x_hat" not in predictions: return recon = predictions["x_hat"] - embedding = self.build_input(batch)[self._feature_group] + embedding = predictions["recon_target"] + # Restrict reconstruction scoring to the rows the recon loss optimizes + # (the mixed CLIP path masks out pair rows); no mask = score all rows. + recon_mask = predictions.get("recon_mask") + if recon_mask is not None: + recon = recon[recon_mask] + embedding = embedding[recon_mask] self._metric_modules["mse"].update(recon, embedding) self._metric_modules["rel_loss"].update(recon, embedding) self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) @@ -298,7 +315,8 @@ def update_train_metric( ) -> None: """Update train-path metric state. - Default is a no-op: K-Means has no train-time codes, so only models - with a meaningful train signal (RQ-VAE) override this. + Default no-op: the current SID models report metrics at eval (after the + codebook is fit / the decoder is trained), not during training. A + subclass with a meaningful train-time signal may override this. """ return diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 716651089..2057680d7 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -139,11 +139,13 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: "codes": codes, } - # Expose the centroid-sum reconstruction (``x_hat``) for update_metric - # only once fitted — pre-fit it is all-zeros, so omitting it skips the - # eval metrics. (Meaningful only with normalize_residuals=False.) + # Expose the centroid-sum reconstruction (``x_hat``) + its target for + # update_metric only once fitted — pre-fit x_hat is all-zeros, so + # omitting it skips the eval metrics. (Meaningful only with + # normalize_residuals=False.) if self.is_eval and self._quantizer.is_fitted: predictions["x_hat"] = quantized + predictions["recon_target"] = embedding return predictions diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 767662ce4..86cbcc5c8 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -141,6 +141,15 @@ def test_init_raises_on_zero_codebook_entry(self) -> None: with self.assertRaisesRegex(ValueError, "codebook entry must be >= 1"): self._create_model(codebook=[16, 0]) + def test_init_raises_on_zero_dim_feature_group(self) -> None: + """A feature group with total dim 0 fails fast (derived input_dim < 1). + + input_dim is no longer a config knob — it is derived from the group, so + the guard now fires via a 0-dim group rather than an explicit input_dim=0. + """ + with self.assertRaisesRegex(ValueError, "must be >= 1"): + self._create_model(input_dim=0) + def test_predict_collects_buffer(self) -> None: """In train mode, predict reservoir-samples; never fits.""" B, input_dim = 8, 32 @@ -255,7 +264,7 @@ def test_normalize_residuals_end_to_end(self) -> None: self.assertTrue((codes >= 0).all() and (codes < 16).all()) def test_eval_and_inference_predict_contract(self) -> None: - """Eval (post-fit) exposes codes + x_hat; inference is codes-only.""" + """Eval (post-fit) exposes codes + x_hat + recon_target; infer codes-only.""" try: import faiss # noqa: F401 except ImportError: @@ -268,12 +277,12 @@ def test_eval_and_inference_predict_contract(self) -> None: model.predict(_make_batch(B, input_dim)) model.on_train_end() - # Eval mode (fitted): the reconstruction is exposed as ``x_hat`` for - # update_metric; the input embedding is re-extracted from the batch - # there, not threaded through predictions. + # Eval mode (fitted): the reconstruction (``x_hat``) and its target + # (``recon_target``) are both exposed for update_metric, so it scores + # without a second build_input pass over the batch. model.eval() eval_preds = model.predict(_make_batch(B, input_dim)) - self.assertEqual(set(eval_preds.keys()), {"codes", "x_hat"}) + self.assertEqual(set(eval_preds.keys()), {"codes", "x_hat", "recon_target"}) # Inference (serving) mode: codes-only contract. model.set_is_inference(True) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 46b84540a..6c742da40 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -19,7 +19,7 @@ on the model — there is no intermediate ``RQVAE`` module wrapper. """ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import torch from torch import nn @@ -30,6 +30,7 @@ from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, ) +from tzrec.modules.sid.types import ResidualQuantizerOutput from tzrec.protos.model_pb2 import ModelConfig from tzrec.utils.config_util import config_to_kwargs from tzrec.utils.logging_util import logger @@ -95,9 +96,18 @@ def __init__( "losses (the objective) must be set together; got " f"clip_config={self._use_clip}, sid_clip_loss={has_clip_obj}" ) - # The paired group shares the main encoder, so it must match input_dim; - # fail fast here instead of an opaque matmul error on the first forward. + # Validate the CLIP groups up front (parity with the base feature_group + # has_group guard): a typo/missing group otherwise only KeyErrors on the + # first forward, after the whole TorchRec setup. if self._use_clip: + for grp in (self._clip_feature_group, self._clip_pair_feature_group): + if not self.embedding_group.has_group(grp): + raise ValueError( + f"clip group {grp!r} is not in model_config.feature_groups " + f"{self.embedding_group.group_names()}" + ) + # The paired group shares the main encoder, so it must match + # input_dim; fail fast instead of an opaque matmul error. clip_dim = self.embedding_group.group_total_dim(self._clip_feature_group) if clip_dim != self._input_dim: raise ValueError( @@ -106,6 +116,16 @@ def __init__( f"the main feature_group (dim {self._input_dim}); the two " "must match" ) + # The pair flag is read as a single raw column (>0.5); a transformed + # or multi-dim group would silently mis-route rows. + pair_dim = self.embedding_group.group_total_dim( + self._clip_pair_feature_group + ) + if pair_dim != 1: + raise ValueError( + f"clip_pair_feature_group {self._clip_pair_feature_group!r} " + f"must be a single dim-1 raw flag, got total dim {pair_dim}" + ) embed_dim = cfg.embed_dim # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero @@ -184,13 +204,25 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: return self._predict_mixed(grouped) return self._predict_rqvae(embedding) - def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: - """Standard RQ-VAE: encode -> quantize -> decode.""" - z_e = self._encode(embedding) + def _rqvae_pass( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, ResidualQuantizerOutput, torch.Tensor]: + """One RQ-VAE pass over ``x``: encode -> quantize -> decode. + + Returns the encoder output ``z_e`` (commitment operand), the quantizer + output ``quant`` (cluster_ids / latents / quantized_embeddings) and the + decoded reconstruction ``x_hat``. + """ + z_e = self._encode(x) quant = self._quantizer(z_e) + return z_e, quant, self._decode(quant.quantized_embeddings) + + def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: + """Standard RQ-VAE: a single reconstruction pass.""" + z_e, quant, x_hat = self._rqvae_pass(embedding) return { "codes": quant.cluster_ids, - "x_hat": self._decode(quant.quantized_embeddings), + "x_hat": x_hat, "recon_target": embedding, "encoder_out": z_e, "latents": quant.latents, @@ -213,13 +245,8 @@ def _predict_mixed( is_clip_pair_raw = grouped[self._clip_pair_feature_group] clip_mask = is_clip_pair_raw.view(is_clip_pair_raw.shape[0], -1)[:, 0] > 0.5 - z_e1 = self._encode(embedding) - quant1 = self._quantizer(z_e1) - x_hat1 = self._decode(quant1.quantized_embeddings) - - z_e2 = self._encode(fea2) - quant2 = self._quantizer(z_e2) - x_hat2 = self._decode(quant2.quantized_embeddings) + z_e1, quant1, x_hat1 = self._rqvae_pass(embedding) + z_e2, quant2, x_hat2 = self._rqvae_pass(fea2) return { "codes": quant1.cluster_ids, diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index f18328fe5..54e7b33cd 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -23,15 +23,17 @@ from tzrec.utils.state_dict_util import init_parameters -def _features_and_groups(input_dim: int, use_clip: bool = False, clip_dim: int = None): +def _features_and_groups( + input_dim: int, use_clip: bool = False, clip_dim: int = None, pair_dim: int = 1 +): """Real raw features + feature groups for a SID model. Mirrors how every other model test wires inputs: ``create_features`` builds the ``BaseFeature`` objects and ``feature_groups`` (consumed by the model's :class:`EmbeddingGroup`) name the main ``deep`` group — plus, for CLIP, the paired image group and the per-row pair-flag group. ``clip_dim`` (default: - match ``input_dim``) sizes the paired group, so a test can deliberately - mismatch it. + match ``input_dim``) sizes the paired group and ``pair_dim`` (default 1) + sizes the pair-flag group, so a test can deliberately mismatch either. """ def _raw(name: str, dim: int) -> feature_pb2.FeatureConfig: @@ -51,7 +53,7 @@ def _deep(group_name: str, feature_name: str) -> model_pb2.FeatureGroupConfig: if use_clip: feature_cfgs += [ _raw("image_emb", clip_dim if clip_dim is not None else input_dim), - _raw("is_clip_pair", 1), + _raw("is_clip_pair", pair_dim), ] groups += [_deep("clip_image", "image_emb"), _deep("clip_pair", "is_clip_pair")] return create_features(feature_cfgs), groups @@ -303,6 +305,52 @@ def test_clip_feature_group_dim_mismatch_raises(self) -> None: with self.assertRaisesRegex(ValueError, "must match"): SidRqvae(model_config=model_config, features=features, labels=[]) + def test_clip_pair_group_must_be_dim_1(self) -> None: + """A pair-flag group with dim != 1 fails fast (would mis-route rows).""" + features, feature_groups = _features_and_groups(32, use_clip=True, pair_dim=3) + cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) + cfg.clip_config.clip_feature_group = "clip_image" + cfg.clip_config.clip_pair_feature_group = "clip_pair" + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqvae=cfg, losses=[_clip_cfg()] + ) + with self.assertRaisesRegex(ValueError, "dim-1 raw flag"): + SidRqvae(model_config=model_config, features=features, labels=[]) + + def test_clip_group_missing_raises(self) -> None: + """A typo'd clip group name fails fast at init, not a forward KeyError.""" + features, feature_groups = _features_and_groups(32, use_clip=True) + cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) + cfg.clip_config.clip_feature_group = "clip_image" + cfg.clip_config.clip_pair_feature_group = "clip_pairTYPO" + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqvae=cfg, losses=[_clip_cfg()] + ) + with self.assertRaisesRegex(ValueError, "not in model_config.feature_groups"): + SidRqvae(model_config=model_config, features=features, labels=[]) + + def test_eval_metric_masks_clip_pair_rows(self) -> None: + """CLIP eval mse/rel_loss score only the non-pair (recon) rows. + + Training masks the recon loss to non-pair rows; update_metric must apply + the same ``recon_mask`` so the eval metric stays comparable (pair rows, + which the decoder is not trained to reconstruct, must not dilute it). + """ + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim, use_clip=True) + model.eval() + model.init_metric() + + # All-pair batch: recon_mask selects zero rows, so mse observes none. + all_pair = self._clip_batch(B, input_dim, torch.ones(B, 1)) + model.update_metric(model.predict(all_pair), all_pair) + self.assertEqual(model._metric_modules["mse"].total.item(), 0.0) + + # A recon (non-pair) batch then contributes rows. + all_recon = self._clip_batch(B, input_dim, torch.zeros(B, 1)) + model.update_metric(model.predict(all_recon), all_recon) + self.assertGreater(model._metric_modules["mse"].total.item(), 0.0) + def test_clip_mask_uses_flag_not_equality(self) -> None: """The is_clip_pair flag, not bit-exact equality, drives routing. diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 27a042085..4107c4b0f 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -52,8 +52,9 @@ def _prepare_config( ) -> str: """Write an embedding parquet + a SID config pointed at it. - Single dense ``embedding`` column, no labels — SID reads the item - embedding straight from the batch. Returns the saved config path. + Single dense ``embedding`` column, no labels — the config's FG maps it + to the ``item_emb`` feature, which its ``deep`` feature_group feeds to + the model's EmbeddingGroup. Returns the saved config path. """ data_dir = os.path.join(self.test_dir, "sid_data") os.makedirs(data_dir, exist_ok=True) From 02b0b28ba4c3c25ca1acec41bf3fff2c0c337191 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 09:22:08 +0000 Subject: [PATCH 111/129] [chore] SID: drop comments that just restate their error messages/docstring Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_model.py | 2 -- tzrec/models/sid_rqvae.py | 9 --------- 2 files changed, 11 deletions(-) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 06d0dc0cd..d8dc3df19 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -298,8 +298,6 @@ def update_metric( return recon = predictions["x_hat"] embedding = predictions["recon_target"] - # Restrict reconstruction scoring to the rows the recon loss optimizes - # (the mixed CLIP path masks out pair rows); no mask = score all rows. recon_mask = predictions.get("recon_mask") if recon_mask is not None: recon = recon[recon_mask] diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 6c742da40..44830500f 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -77,8 +77,6 @@ def __init__( cfg = self._model_config # SidRqvae proto message - # Structure (clip_config) lives on the model proto; the objective - # (sid_clip_loss) lives in losses. The two must be set together. self._use_clip = cfg.HasField("clip_config") self._clip_feature_group = ( cfg.clip_config.clip_feature_group if self._use_clip else None @@ -96,9 +94,6 @@ def __init__( "losses (the objective) must be set together; got " f"clip_config={self._use_clip}, sid_clip_loss={has_clip_obj}" ) - # Validate the CLIP groups up front (parity with the base feature_group - # has_group guard): a typo/missing group otherwise only KeyErrors on the - # first forward, after the whole TorchRec setup. if self._use_clip: for grp in (self._clip_feature_group, self._clip_pair_feature_group): if not self.embedding_group.has_group(grp): @@ -106,8 +101,6 @@ def __init__( f"clip group {grp!r} is not in model_config.feature_groups " f"{self.embedding_group.group_names()}" ) - # The paired group shares the main encoder, so it must match - # input_dim; fail fast instead of an opaque matmul error. clip_dim = self.embedding_group.group_total_dim(self._clip_feature_group) if clip_dim != self._input_dim: raise ValueError( @@ -116,8 +109,6 @@ def __init__( f"the main feature_group (dim {self._input_dim}); the two " "must match" ) - # The pair flag is read as a single raw column (>0.5); a transformed - # or multi-dim group would silently mis-route rows. pair_dim = self.embedding_group.group_total_dim( self._clip_pair_feature_group ) From 2276c9de94d9f6f17929e634e63f8d60709f1b80 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 11:00:59 +0000 Subject: [PATCH 112/129] [refactor] SID: land review #2 (framework MLP) + #10 (cdist) Verified on test/sid_abstract (ee3ecdc), re-applied onto the current code: - #2: encoder/decoder use the framework MLP(hidden_units) + a bare trailing nn.Linear for the unbounded latent/recon projection (MLP always activates its last layer); behavior-preserving vs the removed private _build_mlp. - #10: VectorQuantizeLayer l2 distance uses torch.cdist(x, codebook, p=2).pow(2); drop the hand-rolled _squared_euclidean_distance helper. cdist matches it to fp32 noise incl. a coincident zero-distance point (finite grad, no NaN). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 26 +++++++++++------------ tzrec/modules/sid/vector_quantize.py | 20 +---------------- tzrec/modules/sid/vector_quantize_test.py | 13 ++++++------ 3 files changed, 20 insertions(+), 39 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 44830500f..9b7520b71 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -27,6 +27,7 @@ from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.models.sid_model import BaseSidModel +from tzrec.modules.mlp import MLP from tzrec.modules.sid.residual_vector_quantizer import ( ResidualVectorQuantizer, ) @@ -55,16 +56,6 @@ class SidRqvae(BaseSidModel): sample_weights (list): sample weight names. """ - @staticmethod - def _build_mlp(dims: List[int]) -> nn.Sequential: - """Build MLP: dims[0] -> ... -> dims[-1], ReLU between hidden layers.""" - layers: List[nn.Module] = [] - for i in range(len(dims) - 1): - layers.append(nn.Linear(dims[i], dims[i + 1])) - if i < len(dims) - 2: # no activation after the last layer - layers.append(nn.ReLU()) - return nn.Sequential(*layers) - def __init__( self, model_config: ModelConfig, @@ -134,10 +125,17 @@ def __init__( # restates them; keys map to the quantizer's use_sinkhorn/iters/epsilon. sinkhorn_cfg = config_to_kwargs(cfg.sinkhorn_config) - self._encoder = self._build_mlp([self._input_dim, *hidden_dims, embed_dim]) - # Decoder is the symmetric reverse of the encoder. - self._decoder = self._build_mlp( - [embed_dim, *reversed(hidden_dims), self._input_dim] + # Framework MLP (Linear+ReLU per hidden) + a bare trailing Linear: the + # latent / reconstruction must be unbounded and MLP always activates its + # last layer, so the projection carries no activation. + self._encoder = nn.Sequential( + MLP(self._input_dim, hidden_units=hidden_dims), + nn.Linear(hidden_dims[-1], embed_dim), + ) + # Decoder mirrors the encoder over the reversed hidden stack. + self._decoder = nn.Sequential( + MLP(embed_dim, hidden_units=list(reversed(hidden_dims))), + nn.Linear(hidden_dims[0], self._input_dim), ) self._quantizer = ResidualVectorQuantizer( diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 4fdb56387..6326b7fd0 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -23,24 +23,6 @@ ) -def _squared_euclidean_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - """Squared L2 distance between rows of ``x`` and ``y``. - - Args: - x (Tensor): data points, shape (N, D). - y (Tensor): centroids, shape (K, D). - - Returns: - Tensor: squared distances, shape (N, K). - - Grad-enabled and branch-free (Gumbel needs grad; STE/Sinkhorn callers add - their own ``no_grad``). - """ - x_sq = x.pow(2).sum(dim=1, keepdim=True) # (N, 1) - y_sq = y.pow(2).sum(dim=1, keepdim=True).t() # (1, K) - return (x_sq + y_sq - 2.0 * x @ y.t()).clamp(min=0.0) - - @torch.no_grad() def _sinkhorn( cost: torch.Tensor, @@ -178,7 +160,7 @@ def _compute_distances(self, x: torch.Tensor) -> torch.Tensor: codebook = self.embedding.weight # (n_embed, D) if self.distance_type == "l2": - distances = _squared_euclidean_distance(x, codebook) + distances = torch.cdist(x, codebook, p=2).pow(2) elif self.distance_type == "cosine": x_norm = F.normalize(x, p=2, dim=1) codebook_norm = F.normalize(codebook, p=2, dim=1) diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index e37ac8029..3ce8f936d 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -17,19 +17,20 @@ from tzrec.modules.sid.types import QuantizeForwardMode from tzrec.modules.sid.vector_quantize import ( VectorQuantizeLayer, - _squared_euclidean_distance, ) class SquaredEuclideanDistanceTest(unittest.TestCase): - """Tests for the squared-L2 distance helper used by VectorQuantizeLayer.""" + """Tests the l2 path of ``VectorQuantizeLayer._compute_distances`` (cdist²).""" - def test_squared_euclidean_distance(self) -> None: + def test_l2_compute_distances(self) -> None: + layer = VectorQuantizeLayer(embed_dim=2, n_embed=2, distance_type="l2") + # Pin the codebook to (0,0) and (0,1) so distances are exact. + layer.embedding.weight.data.copy_(torch.tensor([[0.0, 0.0], [0.0, 1.0]])) x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) - y = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) - d = _squared_euclidean_distance(x, y) + d = layer._compute_distances(x) self.assertEqual(d.shape, (2, 2)) - # row0: dist to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 + # row0: dist² to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) From b33401b664558183d8af523f8b050a649d389d2d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 11:18:55 +0000 Subject: [PATCH 113/129] [chore] bump version to 1.2.21 Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tzrec/version.py b/tzrec/version.py index b9f275da8..c9acfa881 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.20" +__version__ = "1.2.21" From 54071344bc7c0619dc26d92b83a86245cc26b35f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Mon, 22 Jun 2026 11:36:08 +0000 Subject: [PATCH 114/129] [refactor] SID: simplify per review + give BaseSidModel logic a test home - residual_vector_quantizer: drop the no-op `distance_types = [d]*n_layers` list (distance_type is uniform; pass it directly). - infonce_loss: compute pair_mask.float()/n_valid once in forward() instead of re-deriving them in each of the 3 masked-CE calls. - vector_quantize: fix a stale docstring (the commitment loss moved to tzrec.loss.commitment_loss.CommitmentLoss; the named method no longer exists). - test placement: fold the stale-named SquaredEuclideanDistanceTest into VectorQuantizeTest (it tests the layer, not a removed helper); add sid_model_test.py covering the recon_loss factory + _masked_mean, which were only exercised end-to-end through the SidRqvae/SidRqkmeans subclass tests. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/infonce_loss.py | 20 +++--- tzrec/models/sid_model_test.py | 71 +++++++++++++++++++ .../modules/sid/residual_vector_quantizer.py | 4 +- tzrec/modules/sid/vector_quantize.py | 7 +- tzrec/modules/sid/vector_quantize_test.py | 8 +-- 5 files changed, 89 insertions(+), 21 deletions(-) create mode 100644 tzrec/models/sid_model_test.py diff --git a/tzrec/loss/infonce_loss.py b/tzrec/loss/infonce_loss.py index 703c737fd..64cb3ea2b 100644 --- a/tzrec/loss/infonce_loss.py +++ b/tzrec/loss/infonce_loss.py @@ -85,7 +85,8 @@ def _masked_cross_entropy( logits_a: torch.Tensor, logits_b: torch.Tensor, safe_labels: torch.Tensor, - pair_mask: torch.Tensor, + pair_mask_f: torch.Tensor, + n_valid: torch.Tensor, ) -> torch.Tensor: """Masked cross-entropy on column-masked logits, row-masked average. @@ -93,7 +94,8 @@ def _masked_cross_entropy( logits_a: (B, B_global) column-masked logits (view-a branch). logits_b: (B, B_global) column-masked logits (view-b branch). safe_labels: (B,) labels with non-pair rows fallback to a safe col. - pair_mask: (B,) bool, True = pair row. + pair_mask_f: (B,) float pair mask (1.0 = pair row). + n_valid: scalar pair-row count, clamped to >= 1. """ ce_a = F.cross_entropy(logits_a, safe_labels, reduction="none") ce_b = F.cross_entropy(logits_b, safe_labels, reduction="none") @@ -101,9 +103,7 @@ def _masked_cross_entropy( ce_a = torch.nan_to_num(ce_a, nan=0.0) ce_b = torch.nan_to_num(ce_b, nan=0.0) - # Only pair rows contribute; clamp(min=1) keeps a no-pair batch at 0. - n_valid = pair_mask.float().sum().clamp(min=1) - return ((ce_a + ce_b) * pair_mask.float()).sum() / (2 * n_valid) + return ((ce_a + ce_b) * pair_mask_f).sum() / (2 * n_valid) def forward( self, @@ -172,15 +172,17 @@ def forward( fallback = pair_mask.long().argmax() # first pair sample index safe_labels = torch.where(pair_mask, labels, fallback.expand_as(labels)) - # --- Masked CE for three loss groups --- + # --- Masked CE for three loss groups (shared row mask + valid count) --- + pair_mask_f = pair_mask.float() + n_valid = pair_mask_f.sum().clamp(min=1) loss_self = self._masked_cross_entropy( - logits_a_self, logits_b_self, safe_labels, pair_mask + logits_a_self, logits_b_self, safe_labels, pair_mask_f, n_valid ) loss_ori = self._masked_cross_entropy( - logits_a_ori, logits_b_ori, safe_labels, pair_mask + logits_a_ori, logits_b_ori, safe_labels, pair_mask_f, n_valid ) loss_cl = self._masked_cross_entropy( - logits_a_cl, logits_b_cl, safe_labels, pair_mask + logits_a_cl, logits_b_cl, safe_labels, pair_mask_f, n_valid ) loss = (loss_self + loss_ori + loss_cl) / 3 diff --git a/tzrec/models/sid_model_test.py b/tzrec/models/sid_model_test.py new file mode 100644 index 000000000..2f632bbb3 --- /dev/null +++ b/tzrec/models/sid_model_test.py @@ -0,0 +1,71 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from tzrec.models.sid_model import _masked_mean, recon_loss + + +class ReconLossTest(unittest.TestCase): + """Tests for the shared ``BaseSidModel`` reconstruction-distance factory.""" + + def test_l2_is_per_row_mse(self) -> None: + d = recon_loss("l2")(torch.ones(3, 4), torch.zeros(3, 4)) + self.assertEqual(d.shape, (3,)) + torch.testing.assert_close(d, torch.ones(3)) # mean of 1^2 over dim -1 + + def test_l1_is_per_row_mae(self) -> None: + d = recon_loss("l1")(torch.ones(2, 5), torch.zeros(2, 5)) + torch.testing.assert_close(d, torch.ones(2)) + + def test_cos_is_one_minus_cosine(self) -> None: + x = torch.tensor([[1.0, 0.0]]) + # identical vectors -> cosine 1 -> distance 0 + d = recon_loss("cos")(x, x.clone()) + torch.testing.assert_close(d, torch.zeros(1), atol=1e-6, rtol=0) + + @parameterized.expand([("l2",), ("l1",), ("cos",)]) + def test_each_type_finite_and_backprops(self, recon_type) -> None: + x_hat = torch.randn(4, 6, requires_grad=True) + loss = recon_loss(recon_type)(x_hat, torch.randn(4, 6)).mean() + self.assertTrue(torch.isfinite(loss)) + loss.backward() # grad must flow back to the (decoder) input + self.assertIsNotNone(x_hat.grad) + + def test_unknown_type_raises(self) -> None: + with self.assertRaisesRegex(ValueError, "recon_type"): + recon_loss("nope") + + +class MaskedMeanTest(unittest.TestCase): + """Tests for the shared ``BaseSidModel`` masked-mean reduction.""" + + def test_no_mask_is_plain_mean(self) -> None: + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + torch.testing.assert_close(_masked_mean(x), x.mean()) + + def test_mask_averages_over_valid_rows_only(self) -> None: + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + mask = torch.tensor([True, False, True, False]) + torch.testing.assert_close(_masked_mean(x, mask), torch.tensor(2.0)) # (1+3)/2 + + def test_empty_mask_is_zero_not_nan(self) -> None: + out = _masked_mean( + torch.tensor([1.0, 2.0, 3.0]), torch.zeros(3, dtype=torch.bool) + ) + self.assertEqual(out.item(), 0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index d6e7fa2aa..072e5da45 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -146,15 +146,13 @@ def __init__( if is_gumbel and rotation_trick: logger.warning("gumbel_softmax: rotation_trick has no effect; ignoring.") - distance_types = [distance_type] * n_layers - self.layers = nn.ModuleList( [ VectorQuantizeLayer( embed_dim=embed_dim, n_embed=self.n_embed_list[i], forward_mode=mode_enum, - distance_type=distance_types[i], + distance_type=distance_type, use_sinkhorn=use_sinkhorn, sinkhorn_iters=sinkhorn_iters, sinkhorn_epsilon=sinkhorn_epsilon, diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 6326b7fd0..7b50e887a 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -87,9 +87,10 @@ class VectorQuantizeLayer(QuantizeLayer): A gradient-trained ``nn.Embedding`` codebook (the VQ ``QuantizeLayer``), sibling of the K-Means backend's ``KMeansQuantizeLayer``. Maps inputs to a - codebook entry via :meth:`quantize`. Loss-free: the commitment loss lives in - :meth:`ResidualVectorQuantizer._single_commitment_loss`. Sinkhorn - optimal-transport assignment optionally balances codebook usage in training. + codebook entry via :meth:`quantize`. Loss-free: the commitment loss is + computed model-side by :class:`tzrec.loss.commitment_loss.CommitmentLoss` + over the quantizer's per-layer ``latents``. Sinkhorn optimal-transport + assignment optionally balances codebook usage in training. Args: embed_dim (int): dimension of each codebook embedding. diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index 3ce8f936d..da9cd2e74 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -20,8 +20,8 @@ ) -class SquaredEuclideanDistanceTest(unittest.TestCase): - """Tests the l2 path of ``VectorQuantizeLayer._compute_distances`` (cdist²).""" +class VectorQuantizeTest(unittest.TestCase): + """Tests for a single VectorQuantizeLayer layer.""" def test_l2_compute_distances(self) -> None: layer = VectorQuantizeLayer(embed_dim=2, n_embed=2, distance_type="l2") @@ -33,10 +33,6 @@ def test_l2_compute_distances(self) -> None: # row0: dist² to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) - -class VectorQuantizeTest(unittest.TestCase): - """Tests for a single VectorQuantizeLayer layer.""" - @parameterized.expand( [ ("ste_l2", QuantizeForwardMode.STE, "l2", True), From dd576514587294a182bd56383a1f0e762d060e14 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 23 Jun 2026 02:31:00 +0000 Subject: [PATCH 115/129] =?UTF-8?q?[bugfix]=20SID:=20restore=20codebook=20?= =?UTF-8?q?gradient=20=E2=80=94=20RVQ=20quantize=20returns=20raw=20vector?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Regression from review item D5 ("derive apply_ste from forward_mode == STE"). In test/sid_abstract the RVQ residual walk called quantize(apply_ste=False), so each per-layer step returned the RAW codebook vector, which flows into the cumulative `latents` and carries gradient to the codebook. D5 dropped that param and hard-wired the per-layer STE wrap to fire whenever forward_mode==STE, so inside the (input-detached) walk `quantized = x + (q - x).detach()` detached the codebook from `latents`. The commitment loss's codebook term (loss2, ||z_e.detach() - latents||^2) then had ZERO gradient to the codebook: it froze at init, so as the encoder trained the commitment loss grew unbounded (recon stayed fine — the aggregate STE still trains encoder+decoder around the frozen codebook). Symptom: commitment_loss 8.7 -> 33 while recon ~0.01. Fix: VectorQuantizeLayer.quantize() returns the raw codebook vector; the single straight-through estimator is applied on the aggregate in ResidualVectorQuantizer.forward (the only production user of the layer). This matches test/sid_abstract's effective behavior (apply_ste=False for the walk). Add a regression test asserting the RVQ's `latents` carries codebook gradient, and update the standalone layer test to the new (raw-vector) contract. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../sid/residual_vector_quantizer_test.py | 20 +++++++++++++++++++ tzrec/modules/sid/vector_quantize.py | 15 +++++++------- tzrec/modules/sid/vector_quantize_test.py | 9 +++++---- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index 4aee34322..6da13f490 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -81,6 +81,26 @@ def test_ste_codebook_grad_is_detached_on_recon_path(self) -> None: cb_grad = rvq.layers[0].embedding.weight.grad self.assertTrue(cb_grad is None or cb_grad.abs().sum().item() == 0.0) + def test_ste_codebook_grad_flows_via_commitment_latents(self) -> None: + # The codebook trains via the commitment loss, which consumes ``latents``, + # so backward through latents MUST reach the codebook. Regression: a + # per-layer STE wrap once detached the codebook from latents, freezing it + # at init (commitment loss then grew unbounded while recon stayed fine). + torch.manual_seed(0) + rvq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=16, + forward_mode="ste", + use_sinkhorn=False, + kmeans_init=False, + ) + rvq.train() + rvq(torch.randn(16, 8)).latents.sum().backward() + cb_grad = rvq.layers[0].embedding.weight.grad + self.assertIsNotNone(cb_grad) + self.assertGreater(cb_grad.abs().sum().item(), 0.0) + class FaissResidualKmeansTest(unittest.TestCase): """Tests for the FAISS residual K-Means warm-start helper.""" diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 7b50e887a..e3d8a4296 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -232,15 +232,14 @@ def quantize(self, x: torch.Tensor) -> QuantizeOutput: return QuantizeOutput(embeddings=emb, ids=ids) # STE / eval: nearest-neighbour assignment under no_grad, one codebook - # gather. In STE training, wrap with the straight-through estimator so - # grad reaches the encoder. (Under the RVQ residual walk the input is - # detached, so this per-layer wrap is a numeric no-op and the aggregate - # STE in ResidualVectorQuantizer.forward carries the gradient.) + # gather. Return the RAW codebook vector (grad-carrying to the codebook) + # so the residual quantizer's cumulative ``latents`` trains the codebook + # via the commitment loss. The encoder straight-through gradient is + # applied once on the aggregate in ``ResidualVectorQuantizer.forward``; a + # per-layer STE wrap here would detach the codebook from ``latents`` and + # leave it frozen at init. ids = self._find_nearest_embedding(x) - quantized = self.embedding(ids) - if self.training and self.forward_mode == QuantizeForwardMode.STE: - quantized = x + (quantized - x).detach() # grad passes to x - return QuantizeOutput(embeddings=quantized, ids=ids) + return QuantizeOutput(embeddings=self.embedding(ids), ids=ids) def get_codebook_embeddings(self) -> torch.Tensor: """Return the codebook table, shape (n_embed, embed_dim).""" diff --git a/tzrec/modules/sid/vector_quantize_test.py b/tzrec/modules/sid/vector_quantize_test.py index da9cd2e74..eab40e04b 100644 --- a/tzrec/modules/sid/vector_quantize_test.py +++ b/tzrec/modules/sid/vector_quantize_test.py @@ -99,16 +99,17 @@ def test_sinkhorn_epsilon_must_be_positive(self) -> None: embed_dim=8, n_embed=16, use_sinkhorn=True, sinkhorn_epsilon=0.0 ) - def test_train_forward_backward_reaches_input(self) -> None: + def test_train_forward_backward_reaches_codebook(self) -> None: torch.manual_seed(0) vq = VectorQuantizeLayer(embed_dim=8, n_embed=16, use_sinkhorn=False) vq.train() x = torch.randn(5, 8, requires_grad=True) out = vq.quantize(x) out.embeddings.sum().backward() - # STE routes gradient back through x. - self.assertIsNotNone(x.grad) - self.assertTrue(torch.isfinite(x.grad).all()) + # The layer returns the raw codebook vector, so gradient reaches the + # codebook (the encoder STE is applied on the aggregate by the RVQ). + self.assertIsNotNone(vq.embedding.weight.grad) + self.assertTrue(torch.isfinite(vq.embedding.weight.grad).all()) def test_eval_forward_is_plain_lookup(self) -> None: torch.manual_seed(0) From b20352038076947dd87cf7cfa773734b52b254fc Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 23 Jun 2026 06:46:59 +0000 Subject: [PATCH 116/129] [chore] SID: drop redundant/misplaced quantizer tests + fix stale STE comment MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Test audit (all in residual_vector_quantizer_test.py): - drop test_raises_on_too_few_points — duplicate of the N>=K guard already owned by kmeans_quantize_test.py (the test's own comment admits it comes from the shared faiss_kmeans_fit primitive). - drop test_decode_codes_shared_base — decode_codes is a base ResidualQuantizer method; residual_quantizer_test.py covers it (shape + sum + dtype, stronger). - drop test_get_codes_no_grad — shape-only and mis-named; get_codes shape is covered by the base walk test, and the eval no_grad contract by the retained test_forward_get_codes_consistent_eval. Comment audit: - fix the _quantize_layer comment that claimed the per-layer STE wrap is a "numeric no-op" — the codebook-gradient bugfix removed that wrap, so the layer returns the raw codebook vector (the single STE is on the aggregate). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../modules/sid/residual_vector_quantizer.py | 9 +++++---- .../sid/residual_vector_quantizer_test.py | 19 ------------------- 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 072e5da45..194e96212 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -298,10 +298,11 @@ def _quantize_layer( emb (Tensor): the raw codebook vector (STE/eval) or the soft embedding (Gumbel), with grad, shape (B, D). """ - # On the STE residual walk the residual is detached, so the layer's - # straight-through wrap is a numeric no-op; the real STE gradient comes - # from the aggregate STE in :meth:`forward`. Gumbel returns the soft - # embedding that carries grad directly. + # On the STE residual walk the residual is detached and the layer + # returns the raw codebook vector (grad-carrying on the codebook, no + # per-layer STE wrap); the encoder STE gradient is applied once on the + # aggregate in :meth:`forward`. Gumbel returns the soft embedding that + # carries grad directly. out = self.layers[layer_idx].quantize(residual) return out.ids, out.embeddings diff --git a/tzrec/modules/sid/residual_vector_quantizer_test.py b/tzrec/modules/sid/residual_vector_quantizer_test.py index 6da13f490..ec6a062c9 100644 --- a/tzrec/modules/sid/residual_vector_quantizer_test.py +++ b/tzrec/modules/sid/residual_vector_quantizer_test.py @@ -122,16 +122,6 @@ def test_faiss_residual_kmeans_per_layer_centers(self) -> None: # Centroids come back on the input device (CPU fit, device-preserving). self.assertEqual(centers[0].device, samples.device) - def test_raises_on_too_few_points(self) -> None: - # Gained from the shared faiss_kmeans_fit primitive: a clear N>=K error - # before faiss's opaque C++ throw. - try: - import faiss # noqa: F401 - except ImportError: - self.skipTest("faiss not installed") - with self.assertRaisesRegex(RuntimeError, "need >= 8 points"): - faiss_residual_kmeans(torch.randn(4, 6), [8]) - class ResidualVQBranchTest(unittest.TestCase): """Coverage for the rotation-trick STE branch and the kmeans-init guard.""" @@ -205,15 +195,6 @@ def test_forward_output(self) -> None: self.assertEqual(out.latents.shape, (5, 3, 8)) self.assertTrue(torch.isfinite(out.latents).all()) - def test_decode_codes_shared_base(self) -> None: - codes = torch.randint(0, 16, (5, 3)) - recon = self.rvq.decode_codes(codes) - self.assertEqual(recon.shape, (5, 8)) - - def test_get_codes_no_grad(self) -> None: - codes = self.rvq.get_codes(torch.randn(4, 8)) - self.assertEqual(codes.shape, (4, 3)) - def test_forward_get_codes_consistent_eval(self) -> None: """get_codes (shared base walk) matches forward's ids in eval.""" self.rvq.eval() From 4239390ad2c971c78156b37b27bec99ef0d184ae Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 23 Jun 2026 07:13:09 +0000 Subject: [PATCH 117/129] [refactor] SID: extract CLIP wiring into SidRqvae._init_clip Move the ~40-line CLIP read + validation block out of __init__ into a private _init_clip() helper, so __init__ reads as a clean sequence (super -> clip -> encoder/decoder/quantizer). Behavior-identical (same attributes set, same fail-fast ValueErrors); the helper flattens the nesting with an early return when CLIP is off, and documents that it must run after super().__init__() (it needs embedding_group / _input_dim). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 90 ++++++++++++++++++++++----------------- 1 file changed, 50 insertions(+), 40 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 9b7520b71..601f9e182 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -68,46 +68,7 @@ def __init__( cfg = self._model_config # SidRqvae proto message - self._use_clip = cfg.HasField("clip_config") - self._clip_feature_group = ( - cfg.clip_config.clip_feature_group if self._use_clip else None - ) - self._clip_pair_feature_group = ( - cfg.clip_config.clip_pair_feature_group if self._use_clip else None - ) - has_clip_obj = any( - lc.WhichOneof("sid_loss") == "sid_clip_loss" - for lc in self._base_model_config.losses - ) - if self._use_clip != has_clip_obj: - raise ValueError( - "clip_config (model structure) and a sid_clip_loss entry in " - "losses (the objective) must be set together; got " - f"clip_config={self._use_clip}, sid_clip_loss={has_clip_obj}" - ) - if self._use_clip: - for grp in (self._clip_feature_group, self._clip_pair_feature_group): - if not self.embedding_group.has_group(grp): - raise ValueError( - f"clip group {grp!r} is not in model_config.feature_groups " - f"{self.embedding_group.group_names()}" - ) - clip_dim = self.embedding_group.group_total_dim(self._clip_feature_group) - if clip_dim != self._input_dim: - raise ValueError( - f"clip_feature_group {self._clip_feature_group!r} has total " - f"dim {clip_dim}, but it is encoded by the same encoder as " - f"the main feature_group (dim {self._input_dim}); the two " - "must match" - ) - pair_dim = self.embedding_group.group_total_dim( - self._clip_pair_feature_group - ) - if pair_dim != 1: - raise ValueError( - f"clip_pair_feature_group {self._clip_pair_feature_group!r} " - f"must be a single dim-1 raw flag, got total dim {pair_dim}" - ) + self._init_clip() embed_dim = cfg.embed_dim # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero @@ -163,6 +124,55 @@ def __init__( self._use_clip, ) + def _init_clip(self) -> None: + """Read and validate the CLIP dual-encoder wiring (``clip_config``). + + Sets ``_use_clip`` and the paired / pair-flag group names, and enforces: + ``clip_config`` (structure) and a ``sid_clip_loss`` entry (objective) are + set together; the paired group exists and matches ``input_dim`` (it shares + the encoder); the pair-flag group is a single dim-1 raw flag. Must run + after ``super().__init__()`` — it needs ``embedding_group`` / ``_input_dim``. + """ + cfg = self._model_config + self._use_clip = cfg.HasField("clip_config") + self._clip_feature_group = ( + cfg.clip_config.clip_feature_group if self._use_clip else None + ) + self._clip_pair_feature_group = ( + cfg.clip_config.clip_pair_feature_group if self._use_clip else None + ) + has_clip_obj = any( + lc.WhichOneof("sid_loss") == "sid_clip_loss" + for lc in self._base_model_config.losses + ) + if self._use_clip != has_clip_obj: + raise ValueError( + "clip_config (model structure) and a sid_clip_loss entry in " + "losses (the objective) must be set together; got " + f"clip_config={self._use_clip}, sid_clip_loss={has_clip_obj}" + ) + if not self._use_clip: + return + for grp in (self._clip_feature_group, self._clip_pair_feature_group): + if not self.embedding_group.has_group(grp): + raise ValueError( + f"clip group {grp!r} is not in model_config.feature_groups " + f"{self.embedding_group.group_names()}" + ) + clip_dim = self.embedding_group.group_total_dim(self._clip_feature_group) + if clip_dim != self._input_dim: + raise ValueError( + f"clip_feature_group {self._clip_feature_group!r} has total " + f"dim {clip_dim}, but it is encoded by the same encoder as the " + f"main feature_group (dim {self._input_dim}); the two must match" + ) + pair_dim = self.embedding_group.group_total_dim(self._clip_pair_feature_group) + if pair_dim != 1: + raise ValueError( + f"clip_pair_feature_group {self._clip_pair_feature_group!r} must " + f"be a single dim-1 raw flag, got total dim {pair_dim}" + ) + def _encode(self, x: torch.Tensor) -> torch.Tensor: """Encode. (B, input_dim) -> (B, embed_dim).""" return self._encoder(x) From c4cbd64566ee26b4491292dc2f76841203d01eb7 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Tue, 23 Jun 2026 11:13:00 +0000 Subject: [PATCH 118/129] [chore] bump version to 1.2.22 master advanced to 1.2.21; bump past the collision to stay ahead. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tzrec/version.py b/tzrec/version.py index c9acfa881..489ed9a3a 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.21" +__version__ = "1.2.22" From a25be45953804bd6b8fee72bdd353c7523fd7182 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 24 Jun 2026 09:23:58 +0000 Subject: [PATCH 119/129] [chore] SID: drop comments that describe an absence / refactor history MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These describe what the code does NOT do (or used to do) rather than the code: - sid_model.proto: the two "(input_dim is not configured here — ...)" notes on SidRqvae / SidRqkmeans (input_dim is simply not a field). - residual_vector_quantizer.py / types.py: "the commitment loss is no longer computed inside the quantizer" — trimmed to just what `latents` is for. - sid_rqkmeans_test.py: "input_dim is no longer a config knob ... rather than an explicit input_dim=0" history note on the 0-dim-group guard test. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqkmeans_test.py | 6 +----- tzrec/modules/sid/residual_vector_quantizer.py | 3 +-- tzrec/modules/sid/types.py | 7 +++---- tzrec/protos/models/sid_model.proto | 5 ----- 4 files changed, 5 insertions(+), 16 deletions(-) diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 86cbcc5c8..128b5acbc 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -142,11 +142,7 @@ def test_init_raises_on_zero_codebook_entry(self) -> None: self._create_model(codebook=[16, 0]) def test_init_raises_on_zero_dim_feature_group(self) -> None: - """A feature group with total dim 0 fails fast (derived input_dim < 1). - - input_dim is no longer a config knob — it is derived from the group, so - the guard now fires via a 0-dim group rather than an explicit input_dim=0. - """ + """A feature group with total dim 0 fails fast (derived input_dim < 1).""" with self.assertRaisesRegex(ValueError, "must be >= 1"): self._create_model(input_dim=0) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 194e96212..dcff04f66 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -336,8 +336,7 @@ def forward( cluster_ids, aggregated_quants, cumulative = self._residual_pass(walk_input) # Expose the per-layer cumulative quantized vectors (grad-carrying on the - # codebook side) so the model-side CommitmentLoss can consume them; the - # commitment loss is no longer computed inside the quantizer. + # codebook side) so the model-side CommitmentLoss can consume them. latents = torch.stack(cumulative, dim=1) # (B, n_layers, D) # Aggregate STE (STE only; Gumbel already carries grad). diff --git a/tzrec/modules/sid/types.py b/tzrec/modules/sid/types.py index 7b110f023..975935d22 100644 --- a/tzrec/modules/sid/types.py +++ b/tzrec/modules/sid/types.py @@ -44,10 +44,9 @@ class QuantizeOutput(NamedTuple): class ResidualQuantizerOutput(NamedTuple): """Output of the residual quantization module (RQ-VAE backend). - The commitment loss is no longer computed inside the quantizer; the per-layer - cumulative quantized vectors are exposed as ``latents`` so the model-side - commitment loss (:class:`~tzrec.loss.commitment_loss.CommitmentLoss`) can - consume them. + The per-layer cumulative quantized vectors are exposed as ``latents`` so the + model-side commitment loss + (:class:`~tzrec.loss.commitment_loss.CommitmentLoss`) can consume them. Attributes: cluster_ids (Tensor): codebook indices per layer, shape (B, n_layers). diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index fd5b4d9b5..e43703ce8 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -40,8 +40,6 @@ message ClipConfig { message SidRqvae { // === Network structure === - // (input_dim is not configured here — it is derived from the total - // dimension of the `feature_group` built by the model's EmbeddingGroup.) // Quantization latent dimension (encoder output / codebook dim). optional uint32 embed_dim = 2 [default = 64]; // Encoder hidden layer sizes, e.g. [256, 128]. @@ -97,9 +95,6 @@ message SidRqvae { } message SidRqkmeans { - // (input_dim is not configured here — K-Means runs directly on the raw - // embeddings of the `feature_group` built by the model's EmbeddingGroup, - // whose total dim is the K-Means dimension.) // Per-layer cluster counts, e.g. [256, 256, 256]. // List length is the number of residual quantization layers. Entries // may differ per layer (non-uniform codebooks such as [256, 512, 1024] From 46b5ac2da4f752c94a15dea23758a00ac0da4376 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 24 Jun 2026 09:35:22 +0000 Subject: [PATCH 120/129] [refactor] SID: default clip group names to None up front in _init_clip Set _clip_feature_group / _clip_pair_feature_group to None at the top, then assign the real values after the `if not self._use_clip: return` (where CLIP is guaranteed on). Drops the two `if self._use_clip else None` ternaries and never reads cfg.clip_config.* when clip_config is unset. Behavior-identical; the attributes are still always defined and the clip_config<->sid_clip_loss consistency check still runs before the return. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 601f9e182..9a3d69d86 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -134,13 +134,10 @@ def _init_clip(self) -> None: after ``super().__init__()`` — it needs ``embedding_group`` / ``_input_dim``. """ cfg = self._model_config + # Default to no CLIP; the group names stay None unless clip_config is set. + self._clip_feature_group = None + self._clip_pair_feature_group = None self._use_clip = cfg.HasField("clip_config") - self._clip_feature_group = ( - cfg.clip_config.clip_feature_group if self._use_clip else None - ) - self._clip_pair_feature_group = ( - cfg.clip_config.clip_pair_feature_group if self._use_clip else None - ) has_clip_obj = any( lc.WhichOneof("sid_loss") == "sid_clip_loss" for lc in self._base_model_config.losses @@ -153,6 +150,8 @@ def _init_clip(self) -> None: ) if not self._use_clip: return + self._clip_feature_group = cfg.clip_config.clip_feature_group + self._clip_pair_feature_group = cfg.clip_config.clip_pair_feature_group for grp in (self._clip_feature_group, self._clip_pair_feature_group): if not self.embedding_group.has_group(grp): raise ValueError( From 08c1bb212b5adf699a633ce351bb6ca0604e3f5f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 24 Jun 2026 09:52:50 +0000 Subject: [PATCH 121/129] [refactor] SID: move CLIP temperatures into MaskedInfoNCELoss The three learnable contrastive temperatures (logit_scale_self/cl/logit_scale), their init/cap constants and the clamp+exp were on BaseSidModel and passed into the loss through the feats dict. They are loss-internal hyperparameters, so move them into MaskedInfoNCELoss: it declares the nn.Parameters in __init__ and does the clamp(<= ln 100) + exp inside forward. Because the loss is stored in the model's _loss_modules (an nn.ModuleDict), the parameters are still registered, trained and checkpointed exactly as before. BaseSidModel loses three params, the scaled() helper, the two module constants and the numpy/nn imports; the model just hands the loss the four embeds + the pair mask. Behavior-identical (same init, same clamp+exp); the overflow and large-scale tests now reach the temperatures on the loss module. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/infonce_loss.py | 26 ++++++++++++++++++++------ tzrec/loss/infonce_loss_test.py | 30 ++++++++++-------------------- tzrec/models/sid_model.py | 21 +-------------------- tzrec/models/sid_rqvae_test.py | 14 ++++++++------ 4 files changed, 39 insertions(+), 52 deletions(-) diff --git a/tzrec/loss/infonce_loss.py b/tzrec/loss/infonce_loss.py index 64cb3ea2b..5670da36f 100644 --- a/tzrec/loss/infonce_loss.py +++ b/tzrec/loss/infonce_loss.py @@ -11,14 +11,22 @@ """Masked InfoNCE contrastive loss with distributed all-gather support.""" +import math from typing import Dict, List, Optional import torch import torch.distributed as dist import torch.distributed.nn as dist_nn +from torch import nn from torch.nn import functional as F from torch.nn.modules.loss import _Loss +# CLIP temperature init (reference CLIP: log(1 / 0.07)) and the cap applied +# before ``exp`` (reference CLIP clamps to ln(100)): an unbounded temperature +# would overflow to +Inf -> NaN grad -> corrupt param. +_LOGIT_SCALE_INIT = math.log(1 / 0.07) +_LOGIT_SCALE_MAX = math.log(100) + class MaskedInfoNCELoss(_Loss): """Masked InfoNCE contrastive loss for mixed (paired + non-paired) batches. @@ -35,9 +43,9 @@ class MaskedInfoNCELoss(_Loss): 'embed_b': reconstructed (decoder) output of view b 'embed_a_ori': original embedding of view a 'embed_b_ori': original embedding of view b - 'logit_scale_self': scalar temperature: recon-a vs recon-b - 'logit_scale_cl': scalar temperature: recon vs same-view original - 'logit_scale': scalar temperature: recon vs counterpart original + + The three contrastive temperatures (self/ori/cl) are learnable parameters + owned by this module; ``forward`` clamps (to <= ln(100)) and ``exp``s them. Output dict keys: 'loss': scalar mean of the three contrastive losses (self/ori/cl) @@ -48,6 +56,11 @@ def __init__(self) -> None: self.labels: Optional[torch.Tensor] = None self.last_local_batch_size: Optional[int] = None self._rank = dist.get_rank() if dist.is_initialized() else 0 + # Learnable contrastive temperatures, one per group (self / ori / cl); + # registered here so the InfoNCE module is self-contained. + self.logit_scale_self = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) + self.logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) + self.logit_scale = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) @staticmethod def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]: @@ -120,9 +133,10 @@ def forward( embed_b = outputs["embed_b"] embed_a_ori = outputs["embed_a_ori"] embed_b_ori = outputs["embed_b_ori"] - logit_scale = outputs["logit_scale"] - logit_scale_self = outputs["logit_scale_self"] - logit_scale_cl = outputs["logit_scale_cl"] + # Clamp before exp so a large temperature can't overflow to +Inf -> NaN. + logit_scale = self.logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp() + logit_scale_self = self.logit_scale_self.clamp(max=_LOGIT_SCALE_MAX).exp() + logit_scale_cl = self.logit_scale_cl.clamp(max=_LOGIT_SCALE_MAX).exp() local_batch_size = embed_a.size(0) diff --git a/tzrec/loss/infonce_loss_test.py b/tzrec/loss/infonce_loss_test.py index a395d1845..bedd5cdad 100644 --- a/tzrec/loss/infonce_loss_test.py +++ b/tzrec/loss/infonce_loss_test.py @@ -12,7 +12,6 @@ import os import unittest -import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -34,15 +33,11 @@ class MaskedInfoNCELossTest(unittest.TestCase): def _features(self, B: int, D: int) -> dict: torch.manual_seed(0) - scale = torch.tensor(np.log(1 / 0.07)).exp() return { "embed_a": torch.randn(B, D, requires_grad=True), "embed_b": torch.randn(B, D, requires_grad=True), "embed_a_ori": torch.randn(B, D), "embed_b_ori": torch.randn(B, D), - "logit_scale_self": scale, - "logit_scale_cl": scale, - "logit_scale": scale, } def test_forward_all_clip_finite(self) -> None: @@ -98,7 +93,6 @@ def test_recon_columns_excluded_from_negatives(self) -> None: torch.manual_seed(0) B, D = 4, 8 img = torch.randn(B, D) - scale = torch.tensor(10.0) mask = torch.tensor([True, True, False, False]) # rows 2,3 are recon def feats(txt: torch.Tensor, txt_ori: torch.Tensor, img_ori: torch.Tensor): @@ -107,9 +101,6 @@ def feats(txt: torch.Tensor, txt_ori: torch.Tensor, img_ori: torch.Tensor): "embed_b": txt, "embed_a_ori": img_ori, "embed_b_ori": txt_ori, - "logit_scale_self": scale, - "logit_scale_cl": scale, - "logit_scale": scale, } txt, txt_ori, img_ori = (torch.randn(B, D) for _ in range(3)) @@ -125,16 +116,19 @@ def feats(txt: torch.Tensor, txt_ori: torch.Tensor, img_ori: torch.Tensor): def test_mask_holds_under_large_scale(self) -> None: # The column fill is finfo.min (below any real logit) rather than a - # hardcoded -1e4, so masking holds even when logit_scale is large and - # the *_ori operands are un-normalized (real logits can dwarf 1e4). - # Loss/grad must stay finite and acc valid; eval exercises the argmax. + # hardcoded -1e4, so masking holds even when the temperature is large and + # the *_ori operands are un-normalized (real logits can dwarf 1e4). The + # loss's internal clamp caps exp() at <= 100; loss/grad must stay finite. loss_fn = MaskedInfoNCELoss() + with torch.no_grad(): + for p in ( + loss_fn.logit_scale, + loss_fn.logit_scale_self, + loss_fn.logit_scale_cl, + ): + p.fill_(3000.0) loss_fn.eval() feats = self._features(6, 8) - big = torch.tensor(3000.0) - feats["logit_scale"] = big - feats["logit_scale_self"] = big - feats["logit_scale_cl"] = big feats["embed_a_ori"] = feats["embed_a_ori"] * 50 feats["embed_b_ori"] = feats["embed_b_ori"] * 50 mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool) @@ -200,15 +194,11 @@ def _masked_clip_worker(rank: int, world_size: int, port: int) -> None: device = _init(rank, world_size, port) torch.manual_seed(1234 + rank) B, D = 4, 8 - scale = torch.tensor(np.log(1 / 0.07)).exp().to(device) feats = { "embed_a": torch.randn(B, D, device=device, requires_grad=True), "embed_b": torch.randn(B, D, device=device, requires_grad=True), "embed_a_ori": torch.randn(B, D, device=device), "embed_b_ori": torch.randn(B, D, device=device), - "logit_scale_self": scale, - "logit_scale_cl": scale, - "logit_scale": scale, } mask = torch.ones(B, dtype=torch.bool, device=device) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index d8dc3df19..c87e5fa67 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -13,11 +13,9 @@ from typing import Any, Callable, Dict, List, Optional -import numpy as np import torch import torch.nn.functional as F import torchmetrics -from torch import nn from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature @@ -31,12 +29,6 @@ from tzrec.protos.loss_pb2 import LossConfig from tzrec.protos.model_pb2 import ModelConfig -# Cap the CLIP temperatures before ``exp`` (reference CLIP clamps to ln(100)): -# an unbounded ``logit_scale`` overflows to +Inf -> NaN grad -> corrupt param. -_LOGIT_SCALE_MAX = float(np.log(100)) -# CLIP temperature init (reference CLIP: log(1 / 0.07)). -_LOGIT_SCALE_INIT = float(np.log(1 / 0.07)) - def recon_loss( recon_type: str, @@ -186,10 +178,7 @@ def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: commitment_type=cfg.commitment_type, ) elif loss_type == "sid_clip_loss": - # The three learnable contrastive temperatures + the InfoNCE module. - self._logit_scale_self = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) - self._logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) - self._logit_scale = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) + # The InfoNCE module owns its learnable contrastive temperatures. self._loss_modules["sid_clip_loss"] = MaskedInfoNCELoss() else: raise ValueError( @@ -234,19 +223,11 @@ def _sid_loss_impl( ) return {"commitment_loss": loss} elif loss_type == "sid_clip_loss": - - def scaled(p: torch.Tensor) -> torch.Tensor: - # clamp before exp so a large temperature can't overflow to +Inf. - return p.clamp(max=_LOGIT_SCALE_MAX).exp() - feats = { "embed_a": predictions["embed_a"], "embed_b": predictions["embed_b"], "embed_a_ori": predictions["embed_a_ori"], "embed_b_ori": predictions["embed_b_ori"], - "logit_scale_self": scaled(self._logit_scale_self), - "logit_scale_cl": scaled(self._logit_scale_cl), - "logit_scale": scaled(self._logit_scale), } out = self._loss_modules["sid_clip_loss"](feats, predictions["pair_mask"]) return {"sid_clip_loss": out["loss"]} diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 54e7b33cd..2bef03506 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -432,19 +432,21 @@ def test_logit_scale_clamped_prevents_overflow(self) -> None: model = self._create_model(input_dim=input_dim, use_clip=True) model.train() model.init_loss() + # The temperatures live on the InfoNCE module that owns the clamp. + clip = model._loss_modules["sid_clip_loss"] with torch.no_grad(): - model._logit_scale_self.fill_(100.0) - model._logit_scale_cl.fill_(100.0) - model._logit_scale.fill_(100.0) + clip.logit_scale_self.fill_(100.0) + clip.logit_scale_cl.fill_(100.0) + clip.logit_scale.fill_(100.0) batch = self._clip_batch(B, input_dim, torch.ones(B, 1)) losses = model.loss(model.predict(batch), batch) self.assertTrue(torch.isfinite(losses["sid_clip_loss"])) sum(losses.values()).backward() for p in ( - model._logit_scale_self, - model._logit_scale_cl, - model._logit_scale, + clip.logit_scale_self, + clip.logit_scale_cl, + clip.logit_scale, ): self.assertIsNotNone(p.grad) self.assertTrue(torch.isfinite(p.grad).all()) From 714d3c0e003e392fdc71932270f8714393d5fd67 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Wed, 24 Jun 2026 09:58:26 +0000 Subject: [PATCH 122/129] [refactor] SID: restore the clamp+exp helper in MaskedInfoNCELoss /simplify follow-up to the temperature move: the previous commit inlined the clamp(<= ln 100) + exp idiom three times in forward(), dropping the single scaled() helper the pre-refactor sid_model.py had. Restore it as a _scaled() staticmethod so the cap constant + the clamp+exp contract live in one place. Also dedup the overflow test: hoist the three-temperature tuple into `scales` and loop over it for both fill_ and the grad assertions (matching the sibling infonce_loss_test). Behavior-identical. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/infonce_loss.py | 13 +++++++++---- tzrec/models/sid_rqvae_test.py | 12 ++++-------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tzrec/loss/infonce_loss.py b/tzrec/loss/infonce_loss.py index 5670da36f..b90202d8f 100644 --- a/tzrec/loss/infonce_loss.py +++ b/tzrec/loss/infonce_loss.py @@ -62,6 +62,11 @@ def __init__(self) -> None: self.logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) self.logit_scale = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) + @staticmethod + def _scaled(logit_scale: torch.Tensor) -> torch.Tensor: + # Clamp before exp so a large temperature can't overflow to +Inf -> NaN. + return logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp() + @staticmethod def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]: """All-gather tensors across workers with gradient support. @@ -133,10 +138,10 @@ def forward( embed_b = outputs["embed_b"] embed_a_ori = outputs["embed_a_ori"] embed_b_ori = outputs["embed_b_ori"] - # Clamp before exp so a large temperature can't overflow to +Inf -> NaN. - logit_scale = self.logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp() - logit_scale_self = self.logit_scale_self.clamp(max=_LOGIT_SCALE_MAX).exp() - logit_scale_cl = self.logit_scale_cl.clamp(max=_LOGIT_SCALE_MAX).exp() + # The three contrastive temperatures, clamped (<= ln 100) then exp'd. + logit_scale = self._scaled(self.logit_scale) + logit_scale_self = self._scaled(self.logit_scale_self) + logit_scale_cl = self._scaled(self.logit_scale_cl) local_batch_size = embed_a.size(0) diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 2bef03506..8bb5c7a54 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -434,20 +434,16 @@ def test_logit_scale_clamped_prevents_overflow(self) -> None: model.init_loss() # The temperatures live on the InfoNCE module that owns the clamp. clip = model._loss_modules["sid_clip_loss"] + scales = (clip.logit_scale_self, clip.logit_scale_cl, clip.logit_scale) with torch.no_grad(): - clip.logit_scale_self.fill_(100.0) - clip.logit_scale_cl.fill_(100.0) - clip.logit_scale.fill_(100.0) + for p in scales: + p.fill_(100.0) batch = self._clip_batch(B, input_dim, torch.ones(B, 1)) losses = model.loss(model.predict(batch), batch) self.assertTrue(torch.isfinite(losses["sid_clip_loss"])) sum(losses.values()).backward() - for p in ( - clip.logit_scale_self, - clip.logit_scale_cl, - clip.logit_scale, - ): + for p in scales: self.assertIsNotNone(p.grad) self.assertTrue(torch.isfinite(p.grad).all()) From 210bfabdc1c7c43e97244d29f217926a783945f8 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 25 Jun 2026 06:10:31 +0000 Subject: [PATCH 123/129] [refactor] SID: de-CLIP rename + uniform loss-modularization Rename the SID losses to a modality-agnostic "contrastive" scheme, register all three SID losses uniformly in _loss_modules, and tidy feature_group handling. Behavior-preserving: same losses, gradients and init (the contrastive loss is bit-identical to the prior version on forward + both grads). Protos (loss.proto, models/sid_model.proto): - messages ReconLoss/CommitmentLoss/SidClipLoss -> SidReconLoss/ SidCommitmentLoss/SidContrastiveLoss; oneof field sid_clip_loss -> contrastive_loss (field numbers 6/7/8 unchanged). - ClipConfig -> ContrastiveConfig; clip_feature_group -> pair_feature_group; clip_pair_feature_group -> pair_flag_feature_group; clip_config -> contrastive_config. - feature_group: drop the "deep" default -> optional with single-group auto-detect (DLRM-style); multiple groups must name it explicitly. Losses: - new SidReconLoss(_Loss) (mse/l1/cos, reduction="none") replaces the recon_loss factory fn; recon is now a _loss_modules entry like the rest. - commitment_loss.py -> sid_commitment_loss.py, CommitmentLoss -> SidCommitmentLoss(_Loss) (aligned to _Loss for uniformity; behavior-neutral). - infonce_loss.py -> sid_contrastive_loss.py, MaskedInfoNCELoss -> SidContrastiveLoss: explicit (embed_a/b, embed_a/b_ori, pair_mask) args + scalar return (was a feats dict / {"loss": ...}); one batched all-gather (was two); the six logit/CE blocks DRY'd into a loop; logit_scale -> logit_scale_ori. Models: - sid_model.py: uniform _sid_loss_impl dispatch; _resolve_feature_group auto-detect; drop the recon factory + _recon_fn. - sid_rqvae.py: _init_contrastive, attr renames (_use_contrastive / _pair_feature_group / _pair_flag_feature_group), _predict_mixed locals (is_pair_raw / pair_mask); drop CLIP/image-text docstring framing. - sid_rqkmeans.py unchanged (inherits the renamed base). Tests/configs: colocated test renames + a new sid_recon_loss_test.py; the recon factory tests move out of sid_model_test.py; sid_rqvae_clip_mock.config -> sid_rqvae_contrastive_mock.config (new fields + explicit feature_group). Run gen_proto.sh after this change to regenerate the *_pb2 modules. Co-Authored-By: Claude Opus 4.8 (1M context) --- ...mitment_loss.py => sid_commitment_loss.py} | 6 +- ...ss_test.py => sid_commitment_loss_test.py} | 16 +- ...nfonce_loss.py => sid_contrastive_loss.py} | 123 ++++++------- ...s_test.py => sid_contrastive_loss_test.py} | 100 ++++++----- tzrec/loss/sid_recon_loss.py | 54 ++++++ tzrec/loss/sid_recon_loss_test.py | 52 ++++++ tzrec/models/sid_model.py | 113 ++++++------ tzrec/models/sid_model_test.py | 34 +--- tzrec/models/sid_rqvae.py | 93 +++++----- tzrec/models/sid_rqvae_test.py | 162 ++++++++++-------- .../modules/sid/residual_vector_quantizer.py | 2 +- tzrec/modules/sid/types.py | 2 +- tzrec/modules/sid/vector_quantize.py | 3 +- tzrec/protos/loss.proto | 20 +-- tzrec/protos/models/sid_model.proto | 36 ++-- ...nfig => sid_rqvae_contrastive_mock.config} | 32 ++-- 16 files changed, 469 insertions(+), 379 deletions(-) rename tzrec/loss/{commitment_loss.py => sid_commitment_loss.py} (95%) rename tzrec/loss/{commitment_loss_test.py => sid_commitment_loss_test.py} (83%) rename tzrec/loss/{infonce_loss.py => sid_contrastive_loss.py} (64%) rename tzrec/loss/{infonce_loss_test.py => sid_contrastive_loss_test.py} (72%) create mode 100644 tzrec/loss/sid_recon_loss.py create mode 100644 tzrec/loss/sid_recon_loss_test.py rename tzrec/tests/configs/{sid_rqvae_clip_mock.config => sid_rqvae_contrastive_mock.config} (66%) diff --git a/tzrec/loss/commitment_loss.py b/tzrec/loss/sid_commitment_loss.py similarity index 95% rename from tzrec/loss/commitment_loss.py rename to tzrec/loss/sid_commitment_loss.py index 5fdb7434a..6b2fd913f 100644 --- a/tzrec/loss/commitment_loss.py +++ b/tzrec/loss/sid_commitment_loss.py @@ -9,16 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""CommitmentLoss: VQ-VAE commitment loss for residual quantizers.""" +"""SidCommitmentLoss: VQ-VAE commitment loss for residual quantizers.""" from typing import Sequence import torch import torch.nn.functional as F -from torch import nn +from torch.nn.modules.loss import _Loss -class CommitmentLoss(nn.Module): +class SidCommitmentLoss(_Loss): """Commitment loss between the encoder output and the quantized vectors. Operates on a residual quantizer's per-layer cumulative quantized vectors diff --git a/tzrec/loss/commitment_loss_test.py b/tzrec/loss/sid_commitment_loss_test.py similarity index 83% rename from tzrec/loss/commitment_loss_test.py rename to tzrec/loss/sid_commitment_loss_test.py index 92a0bc09c..1054a4c27 100644 --- a/tzrec/loss/commitment_loss_test.py +++ b/tzrec/loss/sid_commitment_loss_test.py @@ -14,17 +14,17 @@ import torch from parameterized import parameterized -from tzrec.loss.commitment_loss import CommitmentLoss +from tzrec.loss.sid_commitment_loss import SidCommitmentLoss -class CommitmentLossTest(unittest.TestCase): - """Tests for the standalone CommitmentLoss module.""" +class SidCommitmentLossTest(unittest.TestCase): + """Tests for the standalone SidCommitmentLoss module.""" @parameterized.expand([("l2",), ("l1",), ("cos",)]) def test_branch_runs_and_backprops(self, commitment_type) -> None: """Each commitment_type runs end-to-end; grad reaches both operands.""" torch.manual_seed(0) - loss_fn = CommitmentLoss( + loss_fn = SidCommitmentLoss( latent_weight=(1.0, 0.5), commitment_type=commitment_type ) B, L, D = 4, 3, 8 @@ -43,20 +43,20 @@ def test_latent_weight_wrong_length_raises(self) -> None: """latent_weight must be exactly [w1, w2].""" for bad in ([1.0], [1.0, 0.5, 0.25]): with self.assertRaisesRegex(ValueError, "latent_weight"): - CommitmentLoss(latent_weight=bad) + SidCommitmentLoss(latent_weight=bad) def test_invalid_commitment_type_raises(self) -> None: """An unknown commitment_type is rejected.""" with self.assertRaisesRegex(AssertionError, "commitment_type"): - CommitmentLoss(commitment_type="bogus") + SidCommitmentLoss(commitment_type="bogus") def test_weights_scale_the_two_directions(self) -> None: """w1/w2 weight the encoder-toward-quant / quant-toward-encoder terms.""" torch.manual_seed(0) encoder_out = torch.randn(4, 8) latents = torch.randn(4, 3, 8) - base = CommitmentLoss(latent_weight=(1.0, 0.5), commitment_type="l2") - zero = CommitmentLoss(latent_weight=(0.0, 0.0), commitment_type="l2") + base = SidCommitmentLoss(latent_weight=(1.0, 0.5), commitment_type="l2") + zero = SidCommitmentLoss(latent_weight=(0.0, 0.0), commitment_type="l2") self.assertGreater(base(encoder_out, latents).item(), 0.0) self.assertEqual(zero(encoder_out, latents).item(), 0.0) diff --git a/tzrec/loss/infonce_loss.py b/tzrec/loss/sid_contrastive_loss.py similarity index 64% rename from tzrec/loss/infonce_loss.py rename to tzrec/loss/sid_contrastive_loss.py index b90202d8f..768df25c0 100644 --- a/tzrec/loss/infonce_loss.py +++ b/tzrec/loss/sid_contrastive_loss.py @@ -12,7 +12,7 @@ """Masked InfoNCE contrastive loss with distributed all-gather support.""" import math -from typing import Dict, List, Optional +from typing import List, Optional import torch import torch.distributed as dist @@ -28,27 +28,20 @@ _LOGIT_SCALE_MAX = math.log(100) -class MaskedInfoNCELoss(_Loss): - """Masked InfoNCE contrastive loss for mixed (paired + non-paired) batches. +class SidContrastiveLoss(_Loss): + """Masked InfoNCE pair-contrastive loss for mixed (paired + non-paired) batches. - Modality-agnostic: aligns two reconstructed "views" (``embed_a``/``embed_b``) - against each other and against their originals (``embed_a_ori``/ + Modality-agnostic: aligns two reconstructed "views" (``embed_a`` / ``embed_b``) + against each other and against their originals (``embed_a_ori`` / ``embed_b_ori``) with three symmetric InfoNCE terms (self/ori/cl). In a mixed batch, non-pair rows (``pair_mask=False``) must not contribute and must not serve as negatives; row/column masks achieve this without data-dependent branching (``torch.compile``-friendly). - Input dict keys (all embeddings shape (B, dim)): - 'embed_a': reconstructed (decoder) output of view a - 'embed_b': reconstructed (decoder) output of view b - 'embed_a_ori': original embedding of view a - 'embed_b_ori': original embedding of view b - - The three contrastive temperatures (self/ori/cl) are learnable parameters - owned by this module; ``forward`` clamps (to <= ln(100)) and ``exp``s them. - - Output dict keys: - 'loss': scalar mean of the three contrastive losses (self/ori/cl) + ``forward`` takes the four ``(B, dim)`` view embeddings plus the ``(B,)`` pair + mask and returns the scalar mean of the three contrastive terms. The three + temperatures (self/ori/cl) are learnable parameters owned by this module; + ``forward`` clamps (to <= ln(100)) and ``exp``s them. """ def __init__(self) -> None: @@ -57,10 +50,10 @@ def __init__(self) -> None: self.last_local_batch_size: Optional[int] = None self._rank = dist.get_rank() if dist.is_initialized() else 0 # Learnable contrastive temperatures, one per group (self / ori / cl); - # registered here so the InfoNCE module is self-contained. + # registered here so the loss module is self-contained. self.logit_scale_self = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) self.logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) - self.logit_scale = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) + self.logit_scale_ori = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT) @staticmethod def _scaled(logit_scale: torch.Tensor) -> torch.Tensor: @@ -125,22 +118,27 @@ def _masked_cross_entropy( def forward( self, - outputs: Dict[str, torch.Tensor], + embed_a: torch.Tensor, + embed_b: torch.Tensor, + embed_a_ori: torch.Tensor, + embed_b_ori: torch.Tensor, pair_mask: torch.Tensor, - ) -> Dict[str, torch.Tensor]: - """Forward with the pair mask. + ) -> torch.Tensor: + """Compute the masked pair-contrastive loss. Args: - outputs: feature dict, see class docstring. + embed_a: (B, dim) reconstructed (decoder) output of view a. + embed_b: (B, dim) reconstructed (decoder) output of view b. + embed_a_ori: (B, dim) original embedding of view a. + embed_b_ori: (B, dim) original embedding of view b. pair_mask: (B,) bool, True = contrastive-pair sample. + + Returns: + Tensor: scalar mean of the three contrastive terms (self/ori/cl). """ - embed_a = outputs["embed_a"] - embed_b = outputs["embed_b"] - embed_a_ori = outputs["embed_a_ori"] - embed_b_ori = outputs["embed_b_ori"] # The three contrastive temperatures, clamped (<= ln 100) then exp'd. - logit_scale = self._scaled(self.logit_scale) logit_scale_self = self._scaled(self.logit_scale_self) + logit_scale_ori = self._scaled(self.logit_scale_ori) logit_scale_cl = self._scaled(self.logit_scale_cl) local_batch_size = embed_a.size(0) @@ -156,54 +154,45 @@ def forward( embed_a = F.normalize(embed_a, dim=-1, p=2) embed_b = F.normalize(embed_b, dim=-1, p=2) - # All-gather across GPUs (with gradient support) - embed_a_all, embed_b_all = self._all_gather_with_grad([embed_a, embed_b]) - embed_a_all_ori, embed_b_all_ori = self._all_gather_with_grad( - [embed_a_ori, embed_b_ori] + # One batched all-gather for all four operands (gradient-preserving). + embed_a_all, embed_b_all, embed_a_all_ori, embed_b_all_ori = ( + self._all_gather_with_grad([embed_a, embed_b, embed_a_ori, embed_b_ori]) ) - # --- Compute six groups of logits (a/b × self/ori/cl) --- - logits_a_self = logit_scale_self * embed_a @ embed_b_all.t() - logits_b_self = logit_scale_self * embed_b @ embed_a_all.t() - - logits_a_ori = logit_scale * embed_a @ embed_b_all_ori.t() - logits_b_ori = logit_scale * embed_b @ embed_a_all_ori.t() - - logits_a_cl = logit_scale_cl * embed_a @ embed_a_all_ori.t() - logits_b_cl = logit_scale_cl * embed_b @ embed_b_all_ori.t() - - # Mask non-pair columns out of the negatives with the dtype's most negative - # finite value: below any real logit (masks like -inf), but finite so an - # all-non-pair row gives a finite CE/grad instead of 0*NaN. + # Column mask: drop non-pair columns from the negatives. pair_mask_all = self._gather_bool_mask(pair_mask) col_mask = (~pair_mask_all).unsqueeze(0) # (1, B_global) - neg_fill = torch.finfo(logits_a_self.dtype).min - logits_a_self = logits_a_self.masked_fill(col_mask, neg_fill) - logits_b_self = logits_b_self.masked_fill(col_mask, neg_fill) - logits_a_ori = logits_a_ori.masked_fill(col_mask, neg_fill) - logits_b_ori = logits_b_ori.masked_fill(col_mask, neg_fill) - logits_a_cl = logits_a_cl.masked_fill(col_mask, neg_fill) - logits_b_cl = logits_b_cl.masked_fill(col_mask, neg_fill) - - # --- Safe labels: non-pair rows fallback to the first pair column --- + # Safe labels: non-pair rows fall back to the first pair column. labels = self.labels fallback = pair_mask.long().argmax() # first pair sample index safe_labels = torch.where(pair_mask, labels, fallback.expand_as(labels)) - - # --- Masked CE for three loss groups (shared row mask + valid count) --- pair_mask_f = pair_mask.float() n_valid = pair_mask_f.sum().clamp(min=1) - loss_self = self._masked_cross_entropy( - logits_a_self, logits_b_self, safe_labels, pair_mask_f, n_valid - ) - loss_ori = self._masked_cross_entropy( - logits_a_ori, logits_b_ori, safe_labels, pair_mask_f, n_valid - ) - loss_cl = self._masked_cross_entropy( - logits_a_cl, logits_b_cl, safe_labels, pair_mask_f, n_valid - ) - - loss = (loss_self + loss_ori + loss_cl) / 3 - return {"loss": loss} + # Three symmetric contrastive groups, each (scale, a-target, b-target): + # self: recon-a vs recon-b (vs the other recon view) + # ori: recon vs the counterpart original + # cl: recon vs its own-view original + groups = ( + (logit_scale_self, embed_b_all, embed_a_all), + (logit_scale_ori, embed_b_all_ori, embed_a_all_ori), + (logit_scale_cl, embed_a_all_ori, embed_b_all_ori), + ) + loss = embed_a.new_zeros(()) + for scale, a_target, b_target in groups: + logits_a = scale * embed_a @ a_target.t() + logits_b = scale * embed_b @ b_target.t() + # Fill masked columns with the LOGITS dtype's most negative finite + # value: below any real logit (masks like -inf) but finite, so an + # all-non-pair row yields a finite CE/grad instead of 0*NaN. Derive + # it from the logits dtype, not the embeddings': under autocast the + # matmul casts to bf16/fp16 and finfo(embed.dtype=fp32).min would + # overflow masked_fill on the lower-precision logits. + neg_fill = torch.finfo(logits_a.dtype).min + logits_a = logits_a.masked_fill(col_mask, neg_fill) + logits_b = logits_b.masked_fill(col_mask, neg_fill) + loss = loss + self._masked_cross_entropy( + logits_a, logits_b, safe_labels, pair_mask_f, n_valid + ) + return loss / 3 diff --git a/tzrec/loss/infonce_loss_test.py b/tzrec/loss/sid_contrastive_loss_test.py similarity index 72% rename from tzrec/loss/infonce_loss_test.py rename to tzrec/loss/sid_contrastive_loss_test.py index bedd5cdad..65044e557 100644 --- a/tzrec/loss/infonce_loss_test.py +++ b/tzrec/loss/sid_contrastive_loss_test.py @@ -16,20 +16,20 @@ import torch.distributed as dist import torch.multiprocessing as mp -from tzrec.loss.infonce_loss import MaskedInfoNCELoss +from tzrec.loss.sid_contrastive_loss import SidContrastiveLoss from tzrec.utils import misc_util class AllGatherWithGradTest(unittest.TestCase): def test_single_process_identity(self) -> None: a, b = torch.randn(3, 4), torch.randn(3, 4) - out = MaskedInfoNCELoss._all_gather_with_grad([a, b]) + out = SidContrastiveLoss._all_gather_with_grad([a, b]) self.assertIs(out[0], a) self.assertIs(out[1], b) -class MaskedInfoNCELossTest(unittest.TestCase): - """Single-process tests for the masked CLIP loss.""" +class SidContrastiveLossTest(unittest.TestCase): + """Single-process tests for the masked pair-contrastive loss.""" def _features(self, B: int, D: int) -> dict: torch.manual_seed(0) @@ -40,52 +40,51 @@ def _features(self, B: int, D: int) -> dict: "embed_b_ori": torch.randn(B, D), } - def test_forward_all_clip_finite(self) -> None: - loss_fn = MaskedInfoNCELoss() + def test_forward_all_pairs_finite(self) -> None: + loss_fn = SidContrastiveLoss() feats = self._features(6, 8) mask = torch.ones(6, dtype=torch.bool) - out = loss_fn(feats, mask) - self.assertIn("loss", out) - self.assertTrue(torch.isfinite(out["loss"])) - self.assertGreater(out["loss"].item(), 0.0) + loss = loss_fn(**feats, pair_mask=mask) + self.assertTrue(torch.isfinite(loss)) + self.assertGreater(loss.item(), 0.0) def test_all_recon_mask_zero_loss(self) -> None: - loss_fn = MaskedInfoNCELoss() + loss_fn = SidContrastiveLoss() feats = self._features(6, 8) - mask = torch.zeros(6, dtype=torch.bool) # no clip rows - out = loss_fn(feats, mask) - # No clip rows -> masked average is exactly zero (and finite). - self.assertTrue(torch.isfinite(out["loss"])) - self.assertAlmostEqual(out["loss"].item(), 0.0, places=6) + mask = torch.zeros(6, dtype=torch.bool) # no pair rows + loss = loss_fn(**feats, pair_mask=mask) + # No pair rows -> masked average is exactly zero (and finite). + self.assertTrue(torch.isfinite(loss)) + self.assertAlmostEqual(loss.item(), 0.0, places=6) def test_all_recon_mask_finite_gradient(self) -> None: # Regression: with float("-inf") column fill an all-recon batch produced # a NaN gradient (0 * NaN) that survived the row mask. The finite fill - # must keep the backward finite (and zero, since no clip row contributes). - loss_fn = MaskedInfoNCELoss() + # must keep the backward finite (and zero, since no pair row contributes). + loss_fn = SidContrastiveLoss() feats = self._features(6, 8) mask = torch.zeros(6, dtype=torch.bool) - loss_fn(feats, mask)["loss"].backward() + loss_fn(**feats, pair_mask=mask).backward() grad = feats["embed_a"].grad self.assertIsNotNone(grad) self.assertTrue(torch.isfinite(grad).all()) self.assertAlmostEqual(grad.abs().sum().item(), 0.0, places=6) def test_backward_flows_to_embeddings(self) -> None: - loss_fn = MaskedInfoNCELoss() + loss_fn = SidContrastiveLoss() feats = self._features(6, 8) mask = torch.ones(6, dtype=torch.bool) - loss_fn(feats, mask)["loss"].backward() + loss_fn(**feats, pair_mask=mask).backward() self.assertIsNotNone(feats["embed_a"].grad) self.assertTrue(torch.isfinite(feats["embed_a"].grad).all()) def test_recon_columns_excluded_from_negatives(self) -> None: - """A recon row's embedding must not affect a clip row's loss. + """A recon row's embedding must not affect a pair row's loss. Recon rows are dropped as queries (row mask) AND their columns are masked out of the negatives (col_mask). Perturbing the recon rows of EVERY column operand — ``embed_b`` (the self group) and both - ``*_ori`` operands (the ori/cl groups) — must leave the clip rows' loss + ``*_ori`` operands (the ori/cl groups) — must leave the pair rows' loss unchanged; a dropped or inverted ``col_mask`` on any group would fail. Distinct ``embed_a_ori`` / ``embed_b_ori`` so the ori/cl masking is actually exercised (not hidden by a shared tensor). @@ -104,25 +103,37 @@ def feats(txt: torch.Tensor, txt_ori: torch.Tensor, img_ori: torch.Tensor): } txt, txt_ori, img_ori = (torch.randn(B, D) for _ in range(3)) - loss_fn = MaskedInfoNCELoss() + loss_fn = SidContrastiveLoss() loss_fn.eval() - base = loss_fn(feats(txt, txt_ori, img_ori), mask)["loss"] + base = loss_fn(**feats(txt, txt_ori, img_ori), pair_mask=mask) # Perturb ONLY the recon rows of every column operand that feeds negatives. txt2, txt_ori2, img_ori2 = txt.clone(), txt_ori.clone(), img_ori.clone() for t in (txt2, txt_ori2, img_ori2): t[2:] = torch.randn(2, D) - after = loss_fn(feats(txt2, txt_ori2, img_ori2), mask)["loss"] + after = loss_fn(**feats(txt2, txt_ori2, img_ori2), pair_mask=mask) torch.testing.assert_close(base, after) + def test_autocast_bf16_does_not_overflow_masked_fill(self) -> None: + # Regression: the column fill must come from the LOGITS dtype, not the + # (fp32) embeddings. Under autocast the matmul emits bf16, so + # finfo(fp32).min would raise "cannot be converted to BFloat16 without + # overflow" in masked_fill. A mixed pair/non-pair mask exercises the fill. + loss_fn = SidContrastiveLoss() + feats = self._features(6, 8) + mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool) + with torch.autocast("cpu", dtype=torch.bfloat16): + loss = loss_fn(**feats, pair_mask=mask) + self.assertTrue(torch.isfinite(loss)) + def test_mask_holds_under_large_scale(self) -> None: # The column fill is finfo.min (below any real logit) rather than a # hardcoded -1e4, so masking holds even when the temperature is large and # the *_ori operands are un-normalized (real logits can dwarf 1e4). The # loss's internal clamp caps exp() at <= 100; loss/grad must stay finite. - loss_fn = MaskedInfoNCELoss() + loss_fn = SidContrastiveLoss() with torch.no_grad(): for p in ( - loss_fn.logit_scale, + loss_fn.logit_scale_ori, loss_fn.logit_scale_self, loss_fn.logit_scale_cl, ): @@ -132,17 +143,17 @@ def test_mask_holds_under_large_scale(self) -> None: feats["embed_a_ori"] = feats["embed_a_ori"] * 50 feats["embed_b_ori"] = feats["embed_b_ori"] * 50 mask = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool) - out = loss_fn(feats, mask) - self.assertTrue(torch.isfinite(out["loss"])) + loss = loss_fn(**feats, pair_mask=mask) + self.assertTrue(torch.isfinite(loss)) loss_fn.train() feats["embed_a"].grad = None - loss_fn(feats, mask)["loss"].backward() + loss_fn(**feats, pair_mask=mask).backward() self.assertTrue(torch.isfinite(feats["embed_a"].grad).all()) -# --- Multi-process tests for the CLIP distributed all-gather path. --- +# --- Multi-process tests for the contrastive distributed all-gather path. --- # Validates ``_all_gather_with_grad`` (built on the differentiable -# ``torch.distributed.nn.functional.all_gather``) and ``MaskedInfoNCELoss`` across +# ``torch.distributed.nn.functional.all_gather``) and ``SidContrastiveLoss`` across # ranks. Uses NCCL on GPU when >=2 devices are available (the production path the # reviewer cared about), else falls back to gloo/CPU, so it runs on a multi-GPU # box and in CPU CI alike. @@ -168,7 +179,7 @@ def _all_gather_worker(rank: int, world_size: int, port: int) -> None: device = _init(rank, world_size, port) # Each rank holds a distinct, rank-identifying tensor. x = torch.full((2, 3), float(rank + 1), device=device, requires_grad=True) - gathered = MaskedInfoNCELoss._all_gather_with_grad([x])[0] + gathered = SidContrastiveLoss._all_gather_with_grad([x])[0] # Forward: gathered is (world_size*2, 3); rank r contributes rows # [2r : 2r+2] all equal to (r+1). @@ -190,7 +201,7 @@ def _all_gather_worker(rank: int, world_size: int, port: int) -> None: dist.destroy_process_group() -def _masked_clip_worker(rank: int, world_size: int, port: int) -> None: +def _contrastive_worker(rank: int, world_size: int, port: int) -> None: device = _init(rank, world_size, port) torch.manual_seed(1234 + rank) B, D = 4, 8 @@ -202,13 +213,12 @@ def _masked_clip_worker(rank: int, world_size: int, port: int) -> None: } mask = torch.ones(B, dtype=torch.bool, device=device) - loss_fn = MaskedInfoNCELoss().to(device) - out = loss_fn(feats, mask) - clip_loss = out["loss"] - assert torch.isfinite(clip_loss).all(), f"rank{rank}: non-finite clip_loss" - assert clip_loss.item() > 0.0, f"rank{rank}: clip_loss not positive" + loss_fn = SidContrastiveLoss().to(device) + loss = loss_fn(**feats, pair_mask=mask) + assert torch.isfinite(loss).all(), f"rank{rank}: non-finite loss" + assert loss.item() > 0.0, f"rank{rank}: loss not positive" - clip_loss.backward() + loss.backward() g = feats["embed_a"].grad assert g is not None and torch.isfinite(g).all(), f"rank{rank}: bad grad" dist.destroy_process_group() @@ -228,14 +238,14 @@ def _run(target) -> None: raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") -class InfoNCEDistTest(unittest.TestCase): - """2-rank tests for the CLIP distributed collectives.""" +class SidContrastiveDistTest(unittest.TestCase): + """2-rank tests for the contrastive distributed collectives.""" def test_all_gather_with_grad(self) -> None: _run(_all_gather_worker) - def test_masked_clip_loss(self) -> None: - _run(_masked_clip_worker) + def test_masked_contrastive_loss(self) -> None: + _run(_contrastive_worker) if __name__ == "__main__": diff --git a/tzrec/loss/sid_recon_loss.py b/tzrec/loss/sid_recon_loss.py new file mode 100644 index 000000000..c6ba68fdc --- /dev/null +++ b/tzrec/loss/sid_recon_loss.py @@ -0,0 +1,54 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SidReconLoss: per-row RQ-VAE reconstruction distance (input vs. decoder).""" + +import torch +from torch.nn import functional as F +from torch.nn.modules.loss import _Loss + + +class SidReconLoss(_Loss): + """Per-row reconstruction distance for the configured ``recon_type``. + + ``forward(x_hat, x)`` returns the per-row distance ``(B,)`` + (``reduction="none"``); the model reduces it (a masked mean over the + reconstruction rows). Registered as a ``_loss_modules`` entry alongside the + commitment / contrastive losses. + + Args: + recon_type (str): the distance, ``"l2"`` (mse), ``"l1"`` or ``"cos"``. + Default: ``"l2"``. + """ + + def __init__(self, recon_type: str = "l2") -> None: + super().__init__() + if recon_type not in ("l2", "l1", "cos"): + raise ValueError( + f"recon_type must be 'l2', 'l1' or 'cos', got {recon_type!r}" + ) + self.recon_type = recon_type + + def forward(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Per-row reconstruction distance. + + Args: + x_hat (Tensor): reconstruction (decoder output), shape (B, D). + x (Tensor): the input it reconstructs, shape (B, D). + + Returns: + Tensor: per-row distance, shape (B,). + """ + if self.recon_type == "l1": + return F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) + if self.recon_type == "cos": + return 1 - F.cosine_similarity(x_hat, x, dim=-1) + return F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) # "l2" diff --git a/tzrec/loss/sid_recon_loss_test.py b/tzrec/loss/sid_recon_loss_test.py new file mode 100644 index 000000000..eafa37f51 --- /dev/null +++ b/tzrec/loss/sid_recon_loss_test.py @@ -0,0 +1,52 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from tzrec.loss.sid_recon_loss import SidReconLoss + + +class SidReconLossTest(unittest.TestCase): + """Tests for the per-row reconstruction-distance module.""" + + def test_l2_is_per_row_mse(self) -> None: + d = SidReconLoss("l2")(torch.ones(3, 4), torch.zeros(3, 4)) + self.assertEqual(d.shape, (3,)) + torch.testing.assert_close(d, torch.ones(3)) # mean of 1^2 over dim -1 + + def test_l1_is_per_row_mae(self) -> None: + d = SidReconLoss("l1")(torch.ones(2, 5), torch.zeros(2, 5)) + torch.testing.assert_close(d, torch.ones(2)) + + def test_cos_is_one_minus_cosine(self) -> None: + x = torch.tensor([[1.0, 0.0]]) + # identical vectors -> cosine 1 -> distance 0 + d = SidReconLoss("cos")(x, x.clone()) + torch.testing.assert_close(d, torch.zeros(1), atol=1e-6, rtol=0) + + @parameterized.expand([("l2",), ("l1",), ("cos",)]) + def test_each_type_finite_and_backprops(self, recon_type) -> None: + x_hat = torch.randn(4, 6, requires_grad=True) + loss = SidReconLoss(recon_type)(x_hat, torch.randn(4, 6)).mean() + self.assertTrue(torch.isfinite(loss)) + loss.backward() # grad must flow back to the (decoder) input + self.assertIsNotNone(x_hat.grad) + + def test_unknown_type_raises(self) -> None: + with self.assertRaisesRegex(ValueError, "recon_type"): + SidReconLoss("nope") + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index c87e5fa67..44c5e80ec 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -11,16 +11,16 @@ """BaseSidModel: shared base for semantic-ID generation models.""" -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import torch -import torch.nn.functional as F import torchmetrics from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature -from tzrec.loss.commitment_loss import CommitmentLoss -from tzrec.loss.infonce_loss import MaskedInfoNCELoss +from tzrec.loss.sid_commitment_loss import SidCommitmentLoss +from tzrec.loss.sid_contrastive_loss import SidContrastiveLoss +from tzrec.loss.sid_recon_loss import SidReconLoss from tzrec.metrics.relative_l1 import RelativeL1 from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel @@ -30,34 +30,14 @@ from tzrec.protos.model_pb2 import ModelConfig -def recon_loss( - recon_type: str, -) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: - """Per-row reconstruction-distance fn for the configured ``recon_type``. - - Args: - recon_type (str): the distance, ``"l2"`` (mse), ``"l1"`` or ``"cos"``. - - Returns: - Callable: ``f(x_hat, x) -> (B,)`` per-row reconstruction distance. - """ - if recon_type == "l2": - return lambda x_hat, x: F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) - if recon_type == "l1": - return lambda x_hat, x: F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - if recon_type == "cos": - return lambda x_hat, x: 1 - F.cosine_similarity(x_hat, x, dim=-1) - raise ValueError(f"recon_type must be 'l2', 'l1' or 'cos', got {recon_type!r}") - - def _masked_mean( per_sample: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """Mean of a per-row loss over the masked-in rows (all rows if ``mask`` None). - The mixed recon+CLIP path applies the reconstruction loss to recon rows only; - the masked mean divides by the valid-row count (``div_no_nan`` keeps an empty - mask at 0). No data-dependent branching → ``torch.compile``-friendly. + The mixed recon+contrastive path applies the reconstruction loss to recon rows + only; the masked mean divides by the valid-row count (``div_no_nan`` keeps an + empty mask at 0). No data-dependent branching → ``torch.compile``-friendly. Args: per_sample (Tensor): per-row loss, shape (B,). @@ -112,7 +92,6 @@ def __init__( cfg = self._model_config # Config fields shared by every SID proto message. - self._feature_group = cfg.feature_group self._normalize_residuals = cfg.normalize_residuals if not cfg.codebook: @@ -130,11 +109,7 @@ def __init__( # from the main group's total dim (which may concatenate several # content + side-info features). self.init_input() - if not self.embedding_group.has_group(self._feature_group): - raise ValueError( - f"feature_group {self._feature_group!r} is not in " - f"model_config.feature_groups {self.embedding_group.group_names()}" - ) + self._feature_group = self._resolve_feature_group() self._input_dim = self.embedding_group.group_total_dim(self._feature_group) if self._input_dim < 1: raise ValueError( @@ -150,36 +125,57 @@ def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]: """Build grouped input features: ``{group_name: (B, group_total_dim)}``.""" return self.embedding_group(batch) + def _resolve_feature_group(self) -> str: + """Resolve the main input feature group name. + + Uses ``feature_group`` when set; otherwise, when exactly one group is + declared, that sole group (DLRM-style auto-detect); otherwise fails as + ambiguous. The resolved name must exist in the model's feature groups. + """ + groups = self.embedding_group.group_names() + if self._model_config.HasField("feature_group"): + name = self._model_config.feature_group + if name not in groups: + raise ValueError( + f"feature_group {name!r} is not in model_config.feature_groups " + f"{groups}" + ) + return name + if len(groups) == 1: + return groups[0] + raise ValueError( + "feature_group must be set when multiple feature_groups are declared, " + f"got groups {groups}" + ) + def init_loss(self) -> None: """Initialize SID loss modules from ``ModelConfig.losses``. Each ``LossConfig`` sets one ``sid_loss`` oneof variant (a reconstruction - loss, the commitment loss, or the CLIP loss). Mirrors ``RankModel``: the - config drives what is bound here, and :meth:`loss` computes them from - ``predictions``. The reconstruction loss binds a per-row distance fn into - ``_recon_fn``; commitment/CLIP register modules into ``_loss_modules``. + loss, the commitment loss, or the contrastive loss). Mirrors ``RankModel``: + the config drives what is registered here, and :meth:`loss` computes them + from ``predictions``. All three are registered as ``_loss_modules`` entries. """ - self._recon_fn: Optional[ - Callable[[torch.Tensor, torch.Tensor], torch.Tensor] - ] = None for loss_cfg in self._base_model_config.losses: self._init_sid_loss_impl(loss_cfg) def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: - """Bind the loss (a recon fn or a module) for one ``sid_loss`` config.""" + """Register the loss module for one ``sid_loss`` config.""" loss_type = loss_cfg.WhichOneof("sid_loss") if loss_type == "recon_loss": - self._recon_fn = recon_loss(loss_cfg.recon_loss.recon_type) + self._loss_modules["recon_loss"] = SidReconLoss( + loss_cfg.recon_loss.recon_type + ) elif loss_type == "commitment_loss": cfg = loss_cfg.commitment_loss latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) - self._loss_modules["commitment_loss"] = CommitmentLoss( + self._loss_modules["commitment_loss"] = SidCommitmentLoss( latent_weight=latent_weight, commitment_type=cfg.commitment_type, ) - elif loss_type == "sid_clip_loss": - # The InfoNCE module owns its learnable contrastive temperatures. - self._loss_modules["sid_clip_loss"] = MaskedInfoNCELoss() + elif loss_type == "contrastive_loss": + # The contrastive module owns its learnable temperatures. + self._loss_modules["contrastive_loss"] = SidContrastiveLoss() else: raise ValueError( f"LossConfig for a SID model must set a sid_loss variant, " @@ -194,7 +190,8 @@ def loss( Args: predictions (dict): a dict of predicted result (the raw tensors the losses consume — ``x_hat``/``recon_target`` for reconstruction, - ``encoder_out``/``latents`` for commitment, and the CLIP embeds). + ``encoder_out``/``latents`` for commitment, and the contrastive + embeds). batch (Batch): input batch data. Return: @@ -211,7 +208,7 @@ def _sid_loss_impl( """Compute one ``sid_loss`` term from ``predictions``.""" loss_type = loss_cfg.WhichOneof("sid_loss") if loss_type == "recon_loss": - per_sample = self._recon_fn( + per_sample = self._loss_modules["recon_loss"]( predictions["x_hat"], predictions["recon_target"] ) return { @@ -222,15 +219,15 @@ def _sid_loss_impl( predictions["encoder_out"], predictions["latents"] ) return {"commitment_loss": loss} - elif loss_type == "sid_clip_loss": - feats = { - "embed_a": predictions["embed_a"], - "embed_b": predictions["embed_b"], - "embed_a_ori": predictions["embed_a_ori"], - "embed_b_ori": predictions["embed_b_ori"], - } - out = self._loss_modules["sid_clip_loss"](feats, predictions["pair_mask"]) - return {"sid_clip_loss": out["loss"]} + elif loss_type == "contrastive_loss": + loss = self._loss_modules["contrastive_loss"]( + predictions["embed_a"], + predictions["embed_b"], + predictions["embed_a_ori"], + predictions["embed_b_ori"], + predictions["pair_mask"], + ) + return {"contrastive_loss": loss} else: raise ValueError(f"unsupported sid_loss variant: {loss_type!r}") @@ -266,7 +263,7 @@ def update_metric( Subclasses expose both only when meaningful, so a not-yet-fitted model omits them and this logs nothing. (Reading the target from ``predictions`` avoids a second ``build_input`` pass over ``batch``.) - For the mixed CLIP path the reconstruction is scored only on the + For the mixed contrastive path the reconstruction is scored only on the non-pair rows (``recon_mask``), matching the masked training recon loss so the eval mse/rel_loss stay comparable to the optimized objective. diff --git a/tzrec/models/sid_model_test.py b/tzrec/models/sid_model_test.py index 2f632bbb3..f4eadcf3d 100644 --- a/tzrec/models/sid_model_test.py +++ b/tzrec/models/sid_model_test.py @@ -12,40 +12,8 @@ import unittest import torch -from parameterized import parameterized -from tzrec.models.sid_model import _masked_mean, recon_loss - - -class ReconLossTest(unittest.TestCase): - """Tests for the shared ``BaseSidModel`` reconstruction-distance factory.""" - - def test_l2_is_per_row_mse(self) -> None: - d = recon_loss("l2")(torch.ones(3, 4), torch.zeros(3, 4)) - self.assertEqual(d.shape, (3,)) - torch.testing.assert_close(d, torch.ones(3)) # mean of 1^2 over dim -1 - - def test_l1_is_per_row_mae(self) -> None: - d = recon_loss("l1")(torch.ones(2, 5), torch.zeros(2, 5)) - torch.testing.assert_close(d, torch.ones(2)) - - def test_cos_is_one_minus_cosine(self) -> None: - x = torch.tensor([[1.0, 0.0]]) - # identical vectors -> cosine 1 -> distance 0 - d = recon_loss("cos")(x, x.clone()) - torch.testing.assert_close(d, torch.zeros(1), atol=1e-6, rtol=0) - - @parameterized.expand([("l2",), ("l1",), ("cos",)]) - def test_each_type_finite_and_backprops(self, recon_type) -> None: - x_hat = torch.randn(4, 6, requires_grad=True) - loss = recon_loss(recon_type)(x_hat, torch.randn(4, 6)).mean() - self.assertTrue(torch.isfinite(loss)) - loss.backward() # grad must flow back to the (decoder) input - self.assertIsNotNone(x_hat.grad) - - def test_unknown_type_raises(self) -> None: - with self.assertRaisesRegex(ValueError, "recon_type"): - recon_loss("nope") +from tzrec.models.sid_model import _masked_mean class MaskedMeanTest(unittest.TestCase): diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index 9a3d69d86..efdd139e6 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -12,7 +12,7 @@ """SidRqvae: SID generation model using RQ-VAE (Encoder + VQ + Decoder). End-to-end differentiable training. The reconstruction, commitment and optional -CLIP contrastive losses are configured via ``ModelConfig.losses`` (the +contrastive losses are configured via ``ModelConfig.losses`` (the ``LossConfig`` ``sid_loss`` oneof) and computed centrally in :meth:`BaseSidModel.loss`; :meth:`predict` only produces the raw tensors those losses consume. The encoder/decoder and residual vector quantizer live directly @@ -46,8 +46,8 @@ class SidRqvae(BaseSidModel): (ReLU between hidden layers; the decoder mirrors the encoder.) Losses are config-driven (``ModelConfig.losses`` / ``sid_loss`` oneof). When a - ``sid_clip_loss`` is configured, ``predict`` runs a dual (image/text) path and - the masked CLIP contrastive loss is applied to the CLIP-pair rows. + ``contrastive_loss`` is configured, ``predict`` runs a dual (paired) path and + the masked contrastive loss is applied to the contrastive-pair rows. Args: model_config (ModelConfig): an instance of ModelConfig. @@ -68,7 +68,7 @@ def __init__( cfg = self._model_config # SidRqvae proto message - self._init_clip() + self._init_contrastive() embed_dim = cfg.embed_dim # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero @@ -115,61 +115,64 @@ def __init__( logger.info( "SidRqvae init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " - "n_layers=%d, n_embed=%s, use_clip=%s", + "n_layers=%d, n_embed=%s, use_contrastive=%s", self._input_dim, embed_dim, hidden_dims, self._n_layers, self._n_embed_list, - self._use_clip, + self._use_contrastive, ) - def _init_clip(self) -> None: - """Read and validate the CLIP dual-encoder wiring (``clip_config``). + def _init_contrastive(self) -> None: + """Read and validate the pair-contrastive wiring (``contrastive_config``). - Sets ``_use_clip`` and the paired / pair-flag group names, and enforces: - ``clip_config`` (structure) and a ``sid_clip_loss`` entry (objective) are - set together; the paired group exists and matches ``input_dim`` (it shares - the encoder); the pair-flag group is a single dim-1 raw flag. Must run - after ``super().__init__()`` — it needs ``embedding_group`` / ``_input_dim``. + Sets ``_use_contrastive`` and the paired / pair-flag group names, and + enforces: ``contrastive_config`` (structure) and a ``contrastive_loss`` + entry (objective) are set together; the paired group exists and matches + ``input_dim`` (it shares the encoder); the pair-flag group is a single + dim-1 raw flag. Must run after ``super().__init__()`` — it needs + ``embedding_group`` / ``_input_dim``. """ cfg = self._model_config - # Default to no CLIP; the group names stay None unless clip_config is set. - self._clip_feature_group = None - self._clip_pair_feature_group = None - self._use_clip = cfg.HasField("clip_config") - has_clip_obj = any( - lc.WhichOneof("sid_loss") == "sid_clip_loss" + # Default to no contrastive path; the group names stay None unless + # contrastive_config is set. + self._pair_feature_group = None + self._pair_flag_feature_group = None + self._use_contrastive = cfg.HasField("contrastive_config") + has_contrastive_obj = any( + lc.WhichOneof("sid_loss") == "contrastive_loss" for lc in self._base_model_config.losses ) - if self._use_clip != has_clip_obj: + if self._use_contrastive != has_contrastive_obj: raise ValueError( - "clip_config (model structure) and a sid_clip_loss entry in " - "losses (the objective) must be set together; got " - f"clip_config={self._use_clip}, sid_clip_loss={has_clip_obj}" + "contrastive_config (model structure) and a contrastive_loss entry " + "in losses (the objective) must be set together; got " + f"contrastive_config={self._use_contrastive}, " + f"contrastive_loss={has_contrastive_obj}" ) - if not self._use_clip: + if not self._use_contrastive: return - self._clip_feature_group = cfg.clip_config.clip_feature_group - self._clip_pair_feature_group = cfg.clip_config.clip_pair_feature_group - for grp in (self._clip_feature_group, self._clip_pair_feature_group): + self._pair_feature_group = cfg.contrastive_config.pair_feature_group + self._pair_flag_feature_group = cfg.contrastive_config.pair_flag_feature_group + for grp in (self._pair_feature_group, self._pair_flag_feature_group): if not self.embedding_group.has_group(grp): raise ValueError( - f"clip group {grp!r} is not in model_config.feature_groups " - f"{self.embedding_group.group_names()}" + f"contrastive group {grp!r} is not in model_config.feature_groups" + f" {self.embedding_group.group_names()}" ) - clip_dim = self.embedding_group.group_total_dim(self._clip_feature_group) - if clip_dim != self._input_dim: + pair_dim = self.embedding_group.group_total_dim(self._pair_feature_group) + if pair_dim != self._input_dim: raise ValueError( - f"clip_feature_group {self._clip_feature_group!r} has total " - f"dim {clip_dim}, but it is encoded by the same encoder as the " + f"pair_feature_group {self._pair_feature_group!r} has total " + f"dim {pair_dim}, but it is encoded by the same encoder as the " f"main feature_group (dim {self._input_dim}); the two must match" ) - pair_dim = self.embedding_group.group_total_dim(self._clip_pair_feature_group) - if pair_dim != 1: + flag_dim = self.embedding_group.group_total_dim(self._pair_flag_feature_group) + if flag_dim != 1: raise ValueError( - f"clip_pair_feature_group {self._clip_pair_feature_group!r} must " - f"be a single dim-1 raw flag, got total dim {pair_dim}" + f"pair_flag_feature_group {self._pair_flag_feature_group!r} must " + f"be a single dim-1 raw flag, got total dim {flag_dim}" ) def _encode(self, x: torch.Tensor) -> torch.Tensor: @@ -198,7 +201,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: # Codes-only path: get_codes does just the residual walk (no decode, # no commitment latents), so neither dual-path branch is needed. return {"codes": self._quantizer.get_codes(self._encode(embedding))} - if self._use_clip: + if self._use_contrastive: return self._predict_mixed(grouped) return self._predict_rqvae(embedding) @@ -229,19 +232,19 @@ def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: def _predict_mixed( self, grouped: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: - """Mixed recon + CLIP: dual path over the main + paired feature groups. + """Mixed recon + contrastive: dual path over the main + paired groups. ``encoder_out`` / ``latents`` stack both paths so the commitment loss - averages over them; ``recon_mask`` (= non-CLIP rows) restricts the recon + averages over them; ``recon_mask`` (= non-pair rows) restricts the recon loss to reconstruction-only rows. Args: grouped (dict): the EmbeddingGroup output (group name -> tensor). """ embedding = grouped[self._feature_group] - fea2 = grouped[self._clip_feature_group] - is_clip_pair_raw = grouped[self._clip_pair_feature_group] - clip_mask = is_clip_pair_raw.view(is_clip_pair_raw.shape[0], -1)[:, 0] > 0.5 + fea2 = grouped[self._pair_feature_group] + is_pair_raw = grouped[self._pair_flag_feature_group] + pair_mask = is_pair_raw.view(is_pair_raw.shape[0], -1)[:, 0] > 0.5 z_e1, quant1, x_hat1 = self._rqvae_pass(embedding) z_e2, quant2, x_hat2 = self._rqvae_pass(fea2) @@ -250,7 +253,7 @@ def _predict_mixed( "codes": quant1.cluster_ids, "x_hat": x_hat1, "recon_target": embedding, - "recon_mask": ~clip_mask, + "recon_mask": ~pair_mask, "encoder_out": torch.cat([z_e1, z_e2], dim=0), "latents": torch.cat([quant1.latents, quant2.latents], dim=0), # generic contrastive operands (view a = main, view b = paired): @@ -258,5 +261,5 @@ def _predict_mixed( "embed_b": x_hat2, "embed_a_ori": embedding, "embed_b_ori": fea2, - "pair_mask": clip_mask, + "pair_mask": pair_mask, } diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index 8bb5c7a54..de0ac8757 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -24,16 +24,17 @@ def _features_and_groups( - input_dim: int, use_clip: bool = False, clip_dim: int = None, pair_dim: int = 1 + input_dim: int, use_contrastive: bool = False, pair_emb_dim: int = None, flag_dim=1 ): """Real raw features + feature groups for a SID model. Mirrors how every other model test wires inputs: ``create_features`` builds the ``BaseFeature`` objects and ``feature_groups`` (consumed by the model's - :class:`EmbeddingGroup`) name the main ``deep`` group — plus, for CLIP, the - paired image group and the per-row pair-flag group. ``clip_dim`` (default: - match ``input_dim``) sizes the paired group and ``pair_dim`` (default 1) - sizes the pair-flag group, so a test can deliberately mismatch either. + :class:`EmbeddingGroup`) name the main ``deep`` group — plus, for the + contrastive path, the paired group and the per-row pair-flag group. + ``pair_emb_dim`` (default: match ``input_dim``) sizes the paired group and + ``flag_dim`` (default 1) sizes the pair-flag group, so a test can + deliberately mismatch either. """ def _raw(name: str, dim: int) -> feature_pb2.FeatureConfig: @@ -50,12 +51,12 @@ def _deep(group_name: str, feature_name: str) -> model_pb2.FeatureGroupConfig: feature_cfgs = [_raw("item_emb", input_dim)] groups = [_deep("deep", "item_emb")] - if use_clip: + if use_contrastive: feature_cfgs += [ - _raw("image_emb", clip_dim if clip_dim is not None else input_dim), - _raw("is_clip_pair", pair_dim), + _raw("pair_emb", pair_emb_dim if pair_emb_dim is not None else input_dim), + _raw("is_pair", flag_dim), ] - groups += [_deep("clip_image", "image_emb"), _deep("clip_pair", "is_clip_pair")] + groups += [_deep("pair", "pair_emb"), _deep("pair_flag", "is_pair")] return create_features(feature_cfgs), groups @@ -87,11 +88,11 @@ def _commitment_cfg( return lc -def _clip_cfg() -> loss_pb2.LossConfig: +def _contrastive_cfg() -> loss_pb2.LossConfig: # The contrastive objective marker (empty); the paired-feature wiring lives - # on the model proto (SidRqvae.clip_config), set in _create_model. + # on the model proto (SidRqvae.contrastive_config), set in _create_model. lc = loss_pb2.LossConfig() - lc.sid_clip_loss.SetInParent() + lc.contrastive_loss.SetInParent() return lc @@ -100,7 +101,7 @@ class SidRqvaeTest(unittest.TestCase): def _create_model( self, - use_clip=False, + use_contrastive=False, input_dim=32, embed_dim=8, n_layers=2, @@ -115,13 +116,15 @@ def _create_model( kmeans_init=False, ) losses = [_recon_loss_cfg(recon), _commitment_cfg()] - if use_clip: - sid_rqvae_cfg.clip_config.clip_feature_group = "clip_image" - sid_rqvae_cfg.clip_config.clip_pair_feature_group = "clip_pair" - losses.append(_clip_cfg()) + if use_contrastive: + # Multiple feature groups -> the main group must be named explicitly. + sid_rqvae_cfg.feature_group = "deep" + sid_rqvae_cfg.contrastive_config.pair_feature_group = "pair" + sid_rqvae_cfg.contrastive_config.pair_flag_feature_group = "pair_flag" + losses.append(_contrastive_cfg()) # Real features + feature_groups: input_dim is derived from the group. - features, feature_groups = _features_and_groups(input_dim, use_clip) + features, feature_groups = _features_and_groups(input_dim, use_contrastive) model_config = model_pb2.ModelConfig( feature_groups=feature_groups, sid_rqvae=sid_rqvae_cfg, losses=losses ) @@ -129,15 +132,15 @@ def _create_model( init_parameters(model, device=torch.device("cpu")) return model - def _clip_batch(self, B, input_dim, is_clip_pair): + def _contrastive_batch(self, B, input_dim, is_pair): return Batch( dense_features={ BASE_DATA_GROUP: KeyedTensor.from_tensor_list( - keys=["item_emb", "image_emb", "is_clip_pair"], + keys=["item_emb", "pair_emb", "is_pair"], tensors=[ torch.randn(B, input_dim), torch.randn(B, input_dim), - is_clip_pair, + is_pair, ], ) }, @@ -202,16 +205,16 @@ def test_rqvae_inference_mode(self) -> None: self.assertNotIn("x_hat", predictions) self.assertNotIn("latents", predictions) - def test_rqvae_clip_mode(self) -> None: - """Test SidRqvae with CLIP mixed mode (mixed recon + clip batch).""" + def test_rqvae_contrastive_mode(self) -> None: + """Test SidRqvae with the mixed recon + contrastive path.""" B, input_dim = 8, 32 - model = self._create_model(input_dim=input_dim, use_clip=True) + model = self._create_model(input_dim=input_dim, use_contrastive=True) model.train() model.init_loss() - is_clip_pair = torch.zeros(B, 1) - is_clip_pair[B // 2 :] = 1.0 # second half clip - batch = self._clip_batch(B, input_dim, is_clip_pair) + is_pair = torch.zeros(B, 1) + is_pair[B // 2 :] = 1.0 # second half are contrastive pairs + batch = self._contrastive_batch(B, input_dim, is_pair) predictions = model.predict(batch) self.assertIn("codes", predictions) @@ -222,7 +225,7 @@ def test_rqvae_clip_mode(self) -> None: losses = model.loss(predictions, batch) self.assertIn("recon_loss", losses) self.assertIn("commitment_loss", losses) - self.assertIn("sid_clip_loss", losses) + self.assertIn("contrastive_loss", losses) total_loss = sum(losses.values()) self.assertTrue(total_loss.requires_grad) @@ -232,29 +235,29 @@ def test_rqvae_clip_mode(self) -> None: ) self.assertTrue(has_grad) - def test_rqvae_clip_all_recon(self) -> None: - """Mixed mode with all-recon batch: clip term 0, recon term > 0.""" + def test_rqvae_contrastive_all_recon(self) -> None: + """Mixed mode, all-recon batch: contrastive term 0, recon term > 0.""" B, input_dim = 4, 32 - model = self._create_model(input_dim=input_dim, use_clip=True) + model = self._create_model(input_dim=input_dim, use_contrastive=True) model.train() model.init_loss() - batch = self._clip_batch(B, input_dim, torch.zeros(B, 1)) + batch = self._contrastive_batch(B, input_dim, torch.zeros(B, 1)) losses = model.loss(model.predict(batch), batch) - self.assertEqual(losses["sid_clip_loss"].item(), 0.0) + self.assertEqual(losses["contrastive_loss"].item(), 0.0) self.assertGreater(losses["recon_loss"].item(), 0.0) - def test_rqvae_clip_all_clip(self) -> None: - """Mixed mode with all-clip batch: recon term 0, clip term > 0.""" + def test_rqvae_contrastive_all_pair(self) -> None: + """Mixed mode, all-pair batch: recon term 0, contrastive term > 0.""" B, input_dim = 4, 32 - model = self._create_model(input_dim=input_dim, use_clip=True) + model = self._create_model(input_dim=input_dim, use_contrastive=True) model.train() model.init_loss() - batch = self._clip_batch(B, input_dim, torch.ones(B, 1)) + batch = self._contrastive_batch(B, input_dim, torch.ones(B, 1)) losses = model.loss(model.predict(batch), batch) self.assertEqual(losses["recon_loss"].item(), 0.0) - self.assertGreater(losses["sid_clip_loss"].item(), 0.0) + self.assertGreater(losses["contrastive_loss"].item(), 0.0) def test_rqvae_backward(self) -> None: """Test that backward pass works without errors.""" @@ -288,78 +291,85 @@ def test_commitment_latent_weight_wrong_length_raises(self) -> None: with self.assertRaisesRegex(ValueError, "latent_weight"): model.init_loss() - def test_clip_feature_group_dim_mismatch_raises(self) -> None: - """A CLIP paired group whose dim != the main group fails fast at init. + def test_pair_feature_group_dim_mismatch_raises(self) -> None: + """A paired group whose dim != the main group fails fast at init. The paired feature is encoded by the same encoder as the main input, so a dim mismatch would otherwise crash with an opaque matmul shape error on the first contrastive forward — not at construction. """ - features, feature_groups = _features_and_groups(32, use_clip=True, clip_dim=16) + features, feature_groups = _features_and_groups( + 32, use_contrastive=True, pair_emb_dim=16 + ) cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) - cfg.clip_config.clip_feature_group = "clip_image" - cfg.clip_config.clip_pair_feature_group = "clip_pair" + cfg.feature_group = "deep" + cfg.contrastive_config.pair_feature_group = "pair" + cfg.contrastive_config.pair_flag_feature_group = "pair_flag" model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, sid_rqvae=cfg, losses=[_clip_cfg()] + feature_groups=feature_groups, sid_rqvae=cfg, losses=[_contrastive_cfg()] ) with self.assertRaisesRegex(ValueError, "must match"): SidRqvae(model_config=model_config, features=features, labels=[]) - def test_clip_pair_group_must_be_dim_1(self) -> None: + def test_pair_flag_group_must_be_dim_1(self) -> None: """A pair-flag group with dim != 1 fails fast (would mis-route rows).""" - features, feature_groups = _features_and_groups(32, use_clip=True, pair_dim=3) + features, feature_groups = _features_and_groups( + 32, use_contrastive=True, flag_dim=3 + ) cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) - cfg.clip_config.clip_feature_group = "clip_image" - cfg.clip_config.clip_pair_feature_group = "clip_pair" + cfg.feature_group = "deep" + cfg.contrastive_config.pair_feature_group = "pair" + cfg.contrastive_config.pair_flag_feature_group = "pair_flag" model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, sid_rqvae=cfg, losses=[_clip_cfg()] + feature_groups=feature_groups, sid_rqvae=cfg, losses=[_contrastive_cfg()] ) with self.assertRaisesRegex(ValueError, "dim-1 raw flag"): SidRqvae(model_config=model_config, features=features, labels=[]) - def test_clip_group_missing_raises(self) -> None: - """A typo'd clip group name fails fast at init, not a forward KeyError.""" - features, feature_groups = _features_and_groups(32, use_clip=True) + def test_contrastive_group_missing_raises(self) -> None: + """A typo'd contrastive group name fails fast at init, not on forward.""" + features, feature_groups = _features_and_groups(32, use_contrastive=True) cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) - cfg.clip_config.clip_feature_group = "clip_image" - cfg.clip_config.clip_pair_feature_group = "clip_pairTYPO" + cfg.feature_group = "deep" + cfg.contrastive_config.pair_feature_group = "pair" + cfg.contrastive_config.pair_flag_feature_group = "pair_flagTYPO" model_config = model_pb2.ModelConfig( - feature_groups=feature_groups, sid_rqvae=cfg, losses=[_clip_cfg()] + feature_groups=feature_groups, sid_rqvae=cfg, losses=[_contrastive_cfg()] ) with self.assertRaisesRegex(ValueError, "not in model_config.feature_groups"): SidRqvae(model_config=model_config, features=features, labels=[]) - def test_eval_metric_masks_clip_pair_rows(self) -> None: - """CLIP eval mse/rel_loss score only the non-pair (recon) rows. + def test_eval_metric_masks_contrastive_pair_rows(self) -> None: + """Contrastive eval mse/rel_loss score only the non-pair (recon) rows. Training masks the recon loss to non-pair rows; update_metric must apply the same ``recon_mask`` so the eval metric stays comparable (pair rows, which the decoder is not trained to reconstruct, must not dilute it). """ B, input_dim = 8, 32 - model = self._create_model(input_dim=input_dim, use_clip=True) + model = self._create_model(input_dim=input_dim, use_contrastive=True) model.eval() model.init_metric() # All-pair batch: recon_mask selects zero rows, so mse observes none. - all_pair = self._clip_batch(B, input_dim, torch.ones(B, 1)) + all_pair = self._contrastive_batch(B, input_dim, torch.ones(B, 1)) model.update_metric(model.predict(all_pair), all_pair) self.assertEqual(model._metric_modules["mse"].total.item(), 0.0) # A recon (non-pair) batch then contributes rows. - all_recon = self._clip_batch(B, input_dim, torch.zeros(B, 1)) + all_recon = self._contrastive_batch(B, input_dim, torch.zeros(B, 1)) model.update_metric(model.predict(all_recon), all_recon) self.assertGreater(model._metric_modules["mse"].total.item(), 0.0) - def test_clip_mask_uses_flag_not_equality(self) -> None: - """The is_clip_pair flag, not bit-exact equality, drives routing. + def test_pair_flag_drives_routing_not_equality(self) -> None: + """The is_pair flag, not bit-exact equality, drives routing. - Build a batch where ``image_emb == item_emb`` numerically but - ``is_clip_pair=1``: rows must route to the CLIP branch (under the old + Build a batch where ``pair_emb == item_emb`` numerically but + ``is_pair=1``: rows must route to the contrastive branch (under the old bit-exact logic they would have been silently relabeled recon). """ B, input_dim = 4, 32 - model = self._create_model(input_dim=input_dim, use_clip=True) + model = self._create_model(input_dim=input_dim, use_contrastive=True) model.train() model.init_loss() @@ -367,7 +377,7 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: batch = Batch( dense_features={ BASE_DATA_GROUP: KeyedTensor.from_tensor_list( - keys=["item_emb", "image_emb", "is_clip_pair"], + keys=["item_emb", "pair_emb", "is_pair"], tensors=[item_emb, item_emb.clone(), torch.ones(B, 1)], ) }, @@ -376,7 +386,7 @@ def test_clip_mask_uses_flag_not_equality(self) -> None: ) losses = model.loss(model.predict(batch), batch) self.assertEqual(losses["recon_loss"].item(), 0.0) - self.assertGreater(losses["sid_clip_loss"].item(), 0.0) + self.assertGreater(losses["contrastive_loss"].item(), 0.0) @parameterized.expand( [ @@ -424,24 +434,24 @@ def test_recon_type_branch(self, recon_type) -> None: def test_logit_scale_clamped_prevents_overflow(self) -> None: """A raw logit_scale far above ln(100) must not overflow. - The clamp caps ``exp()`` so the CLIP loss and the parameter gradient - stay finite; without it, ``exp(large)`` -> +Inf -> a NaN gradient that - permanently corrupts the parameter. + The clamp caps ``exp()`` so the contrastive loss and the parameter + gradient stay finite; without it, ``exp(large)`` -> +Inf -> a NaN + gradient that permanently corrupts the parameter. """ B, input_dim = 8, 32 - model = self._create_model(input_dim=input_dim, use_clip=True) + model = self._create_model(input_dim=input_dim, use_contrastive=True) model.train() model.init_loss() - # The temperatures live on the InfoNCE module that owns the clamp. - clip = model._loss_modules["sid_clip_loss"] - scales = (clip.logit_scale_self, clip.logit_scale_cl, clip.logit_scale) + # The temperatures live on the contrastive module that owns the clamp. + clip = model._loss_modules["contrastive_loss"] + scales = (clip.logit_scale_self, clip.logit_scale_cl, clip.logit_scale_ori) with torch.no_grad(): for p in scales: p.fill_(100.0) - batch = self._clip_batch(B, input_dim, torch.ones(B, 1)) + batch = self._contrastive_batch(B, input_dim, torch.ones(B, 1)) losses = model.loss(model.predict(batch), batch) - self.assertTrue(torch.isfinite(losses["sid_clip_loss"])) + self.assertTrue(torch.isfinite(losses["contrastive_loss"])) sum(losses.values()).backward() for p in scales: self.assertIsNotNone(p.grad) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index dcff04f66..5ac44de19 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -336,7 +336,7 @@ def forward( cluster_ids, aggregated_quants, cumulative = self._residual_pass(walk_input) # Expose the per-layer cumulative quantized vectors (grad-carrying on the - # codebook side) so the model-side CommitmentLoss can consume them. + # codebook side) so the model-side SidCommitmentLoss can consume them. latents = torch.stack(cumulative, dim=1) # (B, n_layers, D) # Aggregate STE (STE only; Gumbel already carries grad). diff --git a/tzrec/modules/sid/types.py b/tzrec/modules/sid/types.py index 975935d22..07c31fc9a 100644 --- a/tzrec/modules/sid/types.py +++ b/tzrec/modules/sid/types.py @@ -46,7 +46,7 @@ class ResidualQuantizerOutput(NamedTuple): The per-layer cumulative quantized vectors are exposed as ``latents`` so the model-side commitment loss - (:class:`~tzrec.loss.commitment_loss.CommitmentLoss`) can consume them. + (:class:`~tzrec.loss.sid_commitment_loss.SidCommitmentLoss`) can consume them. Attributes: cluster_ids (Tensor): codebook indices per layer, shape (B, n_layers). diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index e3d8a4296..cb2967aab 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -88,7 +88,8 @@ class VectorQuantizeLayer(QuantizeLayer): A gradient-trained ``nn.Embedding`` codebook (the VQ ``QuantizeLayer``), sibling of the K-Means backend's ``KMeansQuantizeLayer``. Maps inputs to a codebook entry via :meth:`quantize`. Loss-free: the commitment loss is - computed model-side by :class:`tzrec.loss.commitment_loss.CommitmentLoss` + computed model-side by + :class:`tzrec.loss.sid_commitment_loss.SidCommitmentLoss` over the quantizer's per-layer ``latents``. Sinkhorn optimal-transport assignment optionally balances codebook usage in training. diff --git a/tzrec/protos/loss.proto b/tzrec/protos/loss.proto index 5539b3c67..97d62c685 100644 --- a/tzrec/protos/loss.proto +++ b/tzrec/protos/loss.proto @@ -11,23 +11,23 @@ message LossConfig { } // Losses for semantic-ID (SID) generation models (SidRqvae). A SID model // lists one LossConfig per term it trains on (a reconstruction loss, the - // commitment loss, and optionally the CLIP contrastive loss). + // commitment loss, and optionally the pair contrastive loss). oneof sid_loss { - ReconLoss recon_loss = 6; - CommitmentLoss commitment_loss = 7; - SidClipLoss sid_clip_loss = 8; + SidReconLoss recon_loss = 6; + SidCommitmentLoss commitment_loss = 7; + SidContrastiveLoss contrastive_loss = 8; } } // RQ-VAE reconstruction loss (input vs. decoder output). -message ReconLoss { +message SidReconLoss { // Distance for the reconstruction term: "l2" (mse), "l1" or "cos". optional string recon_type = 1 [default = "l2"]; } // RQ-VAE commitment loss between the encoder output and the per-layer // cumulative quantized vectors. -message CommitmentLoss { +message SidCommitmentLoss { // Commitment loss weights [w1, w2] (encoder-toward-quant, quant-toward- // encoder). Defaults to [1.0, 0.5] when unset. repeated float latent_weight = 1; @@ -35,10 +35,10 @@ message CommitmentLoss { optional string commitment_type = 2 [default = "l2"]; } -// Enables the contrastive (masked InfoNCE) objective for a CLIP-style SID model. -// The paired-feature wiring lives on the model (SidRqvae.clip_config); this just -// turns the objective on (any loss hyperparameters would go here). -message SidClipLoss { +// Enables the pair contrastive (masked InfoNCE) objective for a SID model. The +// paired-feature wiring lives on the model (SidRqvae.contrastive_config); this +// just turns the objective on (any loss hyperparameters would go here). +message SidContrastiveLoss { } message BinaryCrossEntropy { diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index e43703ce8..5f2a373ff 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -26,16 +26,16 @@ message SinkhornConfig { optional bool enabled = 3 [default = true]; } -// CLIP-style dual-encoder wiring for SidRqvae: which paired feature group to -// encode and which group flags the contrastive-pair rows. This is model -// structure / input contract (declared on the model), not loss config. -message ClipConfig { +// Pair contrastive (dual-encoder) wiring for SidRqvae: which paired feature +// group to encode and which group flags the contrastive-pair rows. This is +// model structure / input contract (declared on the model), not loss config. +message ContrastiveConfig { // Name of the second (paired) embedding FEATURE GROUP (built by the same // EmbeddingGroup as the main input; same total dim as `feature_group`). - required string clip_feature_group = 1; + required string pair_feature_group = 1; // Name of the per-row pair-flag FEATURE GROUP (a single raw feature, // dim 1; >0.5 = contrastive pair). - required string clip_pair_feature_group = 2; + required string pair_flag_feature_group = 2; } message SidRqvae { @@ -77,12 +77,12 @@ message SidRqvae { // omitted: enabled with iters=5, epsilon=10.0. Include the sub-config // to override params; set ``enabled: false`` inside it to disable. optional SinkhornConfig sinkhorn_config = 15; - // CLIP-style dual-encoder structure: when set, the model encodes a second - // (paired) feature group and runs the contrastive path. This declares the - // model's input contract + topology; the contrastive OBJECTIVE is enabled - // separately by a `sid_clip_loss` entry in ModelConfig.losses (both must be - // set together). - optional ClipConfig clip_config = 16; + // Pair contrastive (dual-encoder) structure: when set, the model encodes a + // second (paired) feature group and runs the contrastive path. This declares + // the model's input contract + topology; the contrastive OBJECTIVE is enabled + // separately by a `contrastive_loss` entry in ModelConfig.losses (both must + // be set together). + optional ContrastiveConfig contrastive_config = 16; // Reconstruction, commitment and (optional) contrastive losses are configured // via ModelConfig.losses (the LossConfig ``sid_loss`` oneof); only the CLIP @@ -90,8 +90,10 @@ message SidRqvae { // Name of the main input FEATURE GROUP (built by the model's EmbeddingGroup // from ModelConfig.feature_groups). May hold one or many content/side-info - // features; their concatenated dim is the encoder input_dim. - optional string feature_group = 40 [default = "deep"]; + // features; their concatenated dim is the encoder input_dim. Optional: when + // unset and exactly one feature group is declared, that group is used; with + // multiple groups it must be set explicitly. + optional string feature_group = 40; } message SidRqkmeans { @@ -115,6 +117,8 @@ message SidRqkmeans { // Name of the main input FEATURE GROUP (built by the model's EmbeddingGroup // from ModelConfig.feature_groups). May hold one or many content/side-info - // features; their concatenated dim is the K-Means dimension. - optional string feature_group = 40 [default = "deep"]; + // features; their concatenated dim is the K-Means dimension. Optional: when + // unset and exactly one feature group is declared, that group is used; with + // multiple groups it must be set explicitly. + optional string feature_group = 40; } diff --git a/tzrec/tests/configs/sid_rqvae_clip_mock.config b/tzrec/tests/configs/sid_rqvae_contrastive_mock.config similarity index 66% rename from tzrec/tests/configs/sid_rqvae_clip_mock.config rename to tzrec/tests/configs/sid_rqvae_contrastive_mock.config index 9a1cf9fa1..7c4154afd 100644 --- a/tzrec/tests/configs/sid_rqvae_clip_mock.config +++ b/tzrec/tests/configs/sid_rqvae_contrastive_mock.config @@ -1,6 +1,6 @@ train_input_path: "" eval_input_path: "" -model_dir: "experiments/sid_rqvae_clip_mock" +model_dir: "experiments/sid_rqvae_contrastive_mock" train_config { sparse_optimizer { adagrad_optimizer { @@ -36,15 +36,15 @@ feature_configs { } feature_configs { raw_feature { - feature_name: "image_emb" - expression: "item:image_embedding" + feature_name: "pair_emb" + expression: "item:pair_embedding" value_dim: 16 } } feature_configs { raw_feature { - feature_name: "is_clip_pair" - expression: "item:is_clip_pair" + feature_name: "is_pair" + expression: "item:is_pair" value_dim: 1 } } @@ -55,13 +55,13 @@ model_config { group_type: DEEP } feature_groups { - group_name: "clip_image" - feature_names: "image_emb" + group_name: "pair" + feature_names: "pair_emb" group_type: DEEP } feature_groups { - group_name: "clip_pair" - feature_names: "is_clip_pair" + group_name: "pair_flag" + feature_names: "is_pair" group_type: DEEP } sid_rqvae { @@ -72,11 +72,13 @@ model_config { codebook: 16 forward_mode: "ste" kmeans_init: false - # clip_image shares the encoder with "deep", so it must match its dim; - # clip_pair flags rows (>0.5 = pair). Objective: sid_clip_loss below. - clip_config { - clip_feature_group: "clip_image" - clip_pair_feature_group: "clip_pair" + # Multiple feature groups -> name the main group explicitly. + feature_group: "deep" + # "pair" shares the encoder with "deep", so it must match its dim; + # "pair_flag" flags rows (>0.5 = pair). Objective: contrastive_loss below. + contrastive_config { + pair_feature_group: "pair" + pair_flag_feature_group: "pair_flag" } } losses { @@ -91,7 +93,7 @@ model_config { } } losses { - sid_clip_loss { + contrastive_loss { } } } From 214c755aa8a3e786fc8dc4317a566b03275882ea Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 25 Jun 2026 06:24:46 +0000 Subject: [PATCH 124/129] [refactor] SID: strip comments that restate logic in the contrastive loss/base Comment-refinement pass over the refactor: drop the temperature/L2-normalize restate comments in SidContrastiveLoss.forward and the "shared config fields" section label in BaseSidModel; reword the label-refresh comment to state why (it carries the cross-rank offset). Non-obvious-trick comments (overflow clamp, NaN backstop, neg_fill dtype, operand pairing, fail-fast, trailing-Linear) are kept. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/sid_contrastive_loss.py | 4 +--- tzrec/models/sid_model.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tzrec/loss/sid_contrastive_loss.py b/tzrec/loss/sid_contrastive_loss.py index 768df25c0..2c2b56ed7 100644 --- a/tzrec/loss/sid_contrastive_loss.py +++ b/tzrec/loss/sid_contrastive_loss.py @@ -136,21 +136,19 @@ def forward( Returns: Tensor: scalar mean of the three contrastive terms (self/ori/cl). """ - # The three contrastive temperatures, clamped (<= ln 100) then exp'd. logit_scale_self = self._scaled(self.logit_scale_self) logit_scale_ori = self._scaled(self.logit_scale_ori) logit_scale_cl = self._scaled(self.logit_scale_cl) local_batch_size = embed_a.size(0) - # Update labels when batch size changes (multi-GPU offset) + # Labels carry the cross-rank offset, so refresh them on batch-size change. if local_batch_size != self.last_local_batch_size: self.labels = local_batch_size * self._rank + torch.arange( local_batch_size, device=embed_a.device ) self.last_local_batch_size = local_batch_size - # L2 normalize the reconstructed features embed_a = F.normalize(embed_a, dim=-1, p=2) embed_b = F.normalize(embed_b, dim=-1, p=2) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index 44c5e80ec..b03cfb60e 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -91,7 +91,6 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) cfg = self._model_config - # Config fields shared by every SID proto message. self._normalize_residuals = cfg.normalize_residuals if not cfg.codebook: From 6a62c6a8048f367fc52ed6b3192fc795ca331706 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 25 Jun 2026 06:40:13 +0000 Subject: [PATCH 125/129] [refactor] SID: strip restate comments (PR-wide comment refinement) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove inline shape annotations (# (B, D) etc.) and other comments that merely restate the adjacent line — the wider codebase does not use inline shape comments, and they are derivable. Kept: public docstrings, and comments that explain a non-obvious trick or rationale (the Eq 4.2 Householder note, the trailing-Linear / unbounded-projection rationale, "first training forward only", "residual, in place", "differentiable", the contrastive col-mask purpose and safe-labels fallback, the finfo/overflow and detach-swap notes). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/models/sid_rqvae.py | 1 - .../modules/sid/residual_vector_quantizer.py | 25 ++++++++----------- tzrec/modules/sid/vector_quantize.py | 8 +++--- 3 files changed, 15 insertions(+), 19 deletions(-) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index efdd139e6..ca1790063 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -93,7 +93,6 @@ def __init__( MLP(self._input_dim, hidden_units=hidden_dims), nn.Linear(hidden_dims[-1], embed_dim), ) - # Decoder mirrors the encoder over the reversed hidden stack. self._decoder = nn.Sequential( MLP(embed_dim, hidden_units=list(reversed(hidden_dims))), nn.Linear(hidden_dims[0], self._input_dim), diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 5ac44de19..b9f25b803 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -63,7 +63,7 @@ def faiss_residual_kmeans( res_centers: List[torch.Tensor] = [] for n_clusters in n_clusters_list: km = faiss_kmeans_fit(x, D, n_clusters, faiss_kmeans_kwargs) - centroids = km.centroids.copy() # (K, D) + centroids = km.centroids.copy() res_centers.append(torch.from_numpy(centroids).to(device)) _, idx = km.index.search(x, 1) x -= centroids[idx.ravel()] # residual, in place @@ -255,26 +255,24 @@ def _apply_rotation_trick( quant_detached = quant.detach() x_detached = x.detach() - quant_norms = torch.linalg.vector_norm(quant_detached, dim=-1).unsqueeze( - 1 - ) # (B, 1) - x_norms = torch.linalg.vector_norm(x_detached, dim=-1).unsqueeze(1) # (B, 1) - lambda_ = quant_norms / (x_norms + 1e-8) # (B, 1) + quant_norms = torch.linalg.vector_norm(quant_detached, dim=-1).unsqueeze(1) + x_norms = torch.linalg.vector_norm(x_detached, dim=-1).unsqueeze(1) + lambda_ = quant_norms / (x_norms + 1e-8) - x_hat = x_detached / (x_norms + 1e-8) # (B, D) - quant_hat = quant_detached / (quant_norms + 1e-8) # (B, D) + x_hat = x_detached / (x_norms + 1e-8) + quant_hat = quant_detached / (quant_norms + 1e-8) - normalized_sum = F.normalize(x_hat + quant_hat, p=2, dim=1) # (B, D) + normalized_sum = F.normalize(x_hat + quant_hat, p=2, dim=1) - x_unsq = x.unsqueeze(1) # (B, 1, D) + x_unsq = x.unsqueeze(1) # Eq 4.2: Householder reflection sum_projection = ( x_unsq @ normalized_sum.unsqueeze(2) @ normalized_sum.unsqueeze(1) - ) # (B, 1, D) + ) rescaled_embeddings = ( x_unsq @ x_hat.unsqueeze(2) @ quant_hat.unsqueeze(1) - ) # (B, 1, D) + ) return lambda_ * ( x_unsq - 2 * sum_projection + 2 * rescaled_embeddings ).squeeze(1) @@ -331,13 +329,12 @@ def forward( self.training and self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX ) - # cumulative[i] = sum after layer i. walk_input = input if train_gumbel else input.detach() cluster_ids, aggregated_quants, cumulative = self._residual_pass(walk_input) # Expose the per-layer cumulative quantized vectors (grad-carrying on the # codebook side) so the model-side SidCommitmentLoss can consume them. - latents = torch.stack(cumulative, dim=1) # (B, n_layers, D) + latents = torch.stack(cumulative, dim=1) # Aggregate STE (STE only; Gumbel already carries grad). quants_trunc = aggregated_quants diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index cb2967aab..035805980 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -79,7 +79,7 @@ def _sinkhorn( # Step 4: scale back so columns sum to 1 (assignment) Q *= B - return Q.t() # (B, K) + return Q.t() class VectorQuantizeLayer(QuantizeLayer): @@ -159,7 +159,7 @@ def _compute_distances(self, x: torch.Tensor) -> torch.Tensor: Returns: Tensor: pairwise distances, shape (B, n_embed). """ - codebook = self.embedding.weight # (n_embed, D) + codebook = self.embedding.weight if self.distance_type == "l2": distances = torch.cdist(x, codebook, p=2).pow(2) @@ -188,7 +188,7 @@ def _find_nearest_embedding(self, x: torch.Tensor) -> torch.Tensor: Returns: Tensor: codebook indices, shape (B,). """ - distances = self._compute_distances(x) # (B, n_embed) + distances = self._compute_distances(x) if self.training and self.use_sinkhorn: # Sinkhorn requires non-negative cost; z-score then shift. @@ -224,7 +224,7 @@ def quantize(self, x: torch.Tensor) -> QuantizeOutput: # Gumbel: grad-enabled distances feed the encoder; the hard sample drives # both emb and ids, so the saved code matches the vector used. if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: - logits = -self._compute_distances(x) # (B, n_embed), differentiable + logits = -self._compute_distances(x) # differentiable weights = F.gumbel_softmax( logits, tau=self.gumbel_temperature, hard=True, dim=-1 ) From 527da1825aa0341f293c8edd450fd135a818d542 Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 25 Jun 2026 06:54:29 +0000 Subject: [PATCH 126/129] [refactor] SID: streamline narration/rationale comments Stricter comment pass: strip comments that narrate control flow / structure or restate the adjacent code (the "default to no contrastive path", "codes-only path", Sinkhorn "Step 1..4" labels, "first training forward only", the gumbel/sinkhorn auto-disable narration whose logged warning already says it, etc.). Condense the trailing-Linear and codebook-freeze notes. Kept only what guards a real mistake: comments whose removal would let a future edit reintroduce a bug (the codebook-freeze "no per-layer STE wrap", the DDP rank-0 faiss-broadcast deadlock ordering, the sinkhorn_epsilon>0 overflow and non-negative-cost requirements), external refs (Eq 4.2 Householder), the non-obvious contrastive operand-pairing table, and public docstrings. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/sid_contrastive_loss.py | 2 -- tzrec/models/sid_model.py | 4 --- tzrec/models/sid_rqkmeans.py | 4 --- tzrec/models/sid_rqvae.py | 12 ++------- .../modules/sid/residual_kmeans_quantizer.py | 1 - .../modules/sid/residual_vector_quantizer.py | 18 ++----------- tzrec/modules/sid/vector_quantize.py | 27 ++++--------------- 7 files changed, 9 insertions(+), 59 deletions(-) diff --git a/tzrec/loss/sid_contrastive_loss.py b/tzrec/loss/sid_contrastive_loss.py index 2c2b56ed7..53b5a6955 100644 --- a/tzrec/loss/sid_contrastive_loss.py +++ b/tzrec/loss/sid_contrastive_loss.py @@ -157,11 +157,9 @@ def forward( self._all_gather_with_grad([embed_a, embed_b, embed_a_ori, embed_b_ori]) ) - # Column mask: drop non-pair columns from the negatives. pair_mask_all = self._gather_bool_mask(pair_mask) col_mask = (~pair_mask_all).unsqueeze(0) # (1, B_global) - # Safe labels: non-pair rows fall back to the first pair column. labels = self.labels fallback = pair_mask.long().argmax() # first pair sample index safe_labels = torch.where(pair_mask, labels, fallback.expand_as(labels)) diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index b03cfb60e..f5b98c132 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -103,10 +103,6 @@ def __init__( ) self._n_layers = len(self._n_embed_list) - # Built in the base __init__ (not the subclass like Rank/Match models) - # so _input_dim is ready before the subclass builds its encoder; derived - # from the main group's total dim (which may concatenate several - # content + side-info features). self.init_input() self._feature_group = self._resolve_feature_group() self._input_dim = self.embedding_group.group_total_dim(self._feature_group) diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 2057680d7..933481fa2 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -139,10 +139,6 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: "codes": codes, } - # Expose the centroid-sum reconstruction (``x_hat``) + its target for - # update_metric only once fitted — pre-fit x_hat is all-zeros, so - # omitting it skips the eval metrics. (Meaningful only with - # normalize_residuals=False.) if self.is_eval and self._quantizer.is_fitted: predictions["x_hat"] = quantized predictions["recon_target"] = embedding diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py index ca1790063..732c520b4 100644 --- a/tzrec/models/sid_rqvae.py +++ b/tzrec/models/sid_rqvae.py @@ -81,14 +81,10 @@ def __init__( if any(h < 1 for h in hidden_dims): raise ValueError(f"every hidden_dims entry must be >= 1, got {hidden_dims}") - # Sinkhorn params from the proto: config_to_kwargs flows the proto - # defaults (enabled=True, iters=5, epsilon=10.0) so the model never - # restates them; keys map to the quantizer's use_sinkhorn/iters/epsilon. sinkhorn_cfg = config_to_kwargs(cfg.sinkhorn_config) - # Framework MLP (Linear+ReLU per hidden) + a bare trailing Linear: the - # latent / reconstruction must be unbounded and MLP always activates its - # last layer, so the projection carries no activation. + # MLP activates its last layer; the trailing bare Linear keeps the + # latent / reconstruction unbounded. self._encoder = nn.Sequential( MLP(self._input_dim, hidden_units=hidden_dims), nn.Linear(hidden_dims[-1], embed_dim), @@ -134,8 +130,6 @@ def _init_contrastive(self) -> None: ``embedding_group`` / ``_input_dim``. """ cfg = self._model_config - # Default to no contrastive path; the group names stay None unless - # contrastive_config is set. self._pair_feature_group = None self._pair_flag_feature_group = None self._use_contrastive = cfg.HasField("contrastive_config") @@ -197,8 +191,6 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: grouped = self.build_input(batch) embedding = grouped[self._feature_group] if self._is_inference: - # Codes-only path: get_codes does just the residual walk (no decode, - # no commitment latents), so neither dual-path branch is needed. return {"codes": self._quantizer.get_codes(self._encode(embedding))} if self._use_contrastive: return self._predict_mixed(grouped) diff --git a/tzrec/modules/sid/residual_kmeans_quantizer.py b/tzrec/modules/sid/residual_kmeans_quantizer.py index b9433a3c6..e94594be8 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -199,7 +199,6 @@ def train_offline( # breaks that invariant, so clone then. x0 = x.clone() if (verbose and self.normalize_residuals) else None - # CPU-only fit (SidRqkmeans refuses CUDA). if verbose: logger.info( "[ResidualKMeansQuantizer] fitting %d-layer codebook on CPU " diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index b9f25b803..2c590df27 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -57,7 +57,6 @@ def faiss_residual_kmeans( """ device = samples.device _, D = samples.shape - # Own a contiguous fp32 numpy copy we mutate in place to form residuals. x = samples.detach().cpu().float().numpy().copy() res_centers: List[torch.Tensor] = [] @@ -66,7 +65,7 @@ def faiss_residual_kmeans( centroids = km.centroids.copy() res_centers.append(torch.from_numpy(centroids).to(device)) _, idx = km.index.search(x, 1) - x -= centroids[idx.ravel()] # residual, in place + x -= centroids[idx.ravel()] return res_centers @@ -125,8 +124,6 @@ def __init__( super().__init__(embed_dim, n_layers, n_embed, normalize_residuals) self.rotation_trick = rotation_trick - # ``initted`` is the kmeans_init guard: True means "codebook has - # been seeded", so init_embed_() becomes a no-op on later forwards. self.register_buffer("initted", torch.tensor([not kmeans_init])) if forward_mode not in self._FORWARD_MODE_MAP: @@ -137,12 +134,9 @@ def __init__( mode_enum = self._FORWARD_MODE_MAP[forward_mode] self._forward_mode = mode_enum is_gumbel = mode_enum == QuantizeForwardMode.GUMBEL_SOFTMAX - # Sinkhorn is incompatible with Gumbel; auto-disable (the proto default - # is on) instead of crashing. if is_gumbel and use_sinkhorn: logger.warning("gumbel_softmax: disabling incompatible use_sinkhorn.") use_sinkhorn = False - # Gumbel skips the aggregate STE, so the rotation trick is unused. if is_gumbel and rotation_trick: logger.warning("gumbel_softmax: rotation_trick has no effect; ignoring.") @@ -296,11 +290,6 @@ def _quantize_layer( emb (Tensor): the raw codebook vector (STE/eval) or the soft embedding (Gumbel), with grad, shape (B, D). """ - # On the STE residual walk the residual is detached and the layer - # returns the raw codebook vector (grad-carrying on the codebook, no - # per-layer STE wrap); the encoder STE gradient is applied once on the - # aggregate in :meth:`forward`. Gumbel returns the soft embedding that - # carries grad directly. out = self.layers[layer_idx].quantize(residual) return out.ids, out.embeddings @@ -323,7 +312,7 @@ def forward( latents). """ if self.training: - self.init_embed_(input) # first training forward only + self.init_embed_(input) train_gumbel = ( self.training and self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX @@ -332,11 +321,8 @@ def forward( walk_input = input if train_gumbel else input.detach() cluster_ids, aggregated_quants, cumulative = self._residual_pass(walk_input) - # Expose the per-layer cumulative quantized vectors (grad-carrying on the - # codebook side) so the model-side SidCommitmentLoss can consume them. latents = torch.stack(cumulative, dim=1) - # Aggregate STE (STE only; Gumbel already carries grad). quants_trunc = aggregated_quants if self.training and not train_gumbel: if self.rotation_trick: diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py index 035805980..4c41b3fbd 100644 --- a/tzrec/modules/sid/vector_quantize.py +++ b/tzrec/modules/sid/vector_quantize.py @@ -48,36 +48,29 @@ def _sinkhorn( Tensor: assignment matrix, shape (B, K). Use Q.argmax(dim=-1) externally to get codebook indices. """ - # Step 1: exponential kernel transform (B, K) -> (K, B) Q = torch.exp(-cost * epsilon).t() - # Global batch size for distributed training if dist.is_initialized(): B = Q.size(1) * dist.get_world_size() else: B = Q.size(1) K = Q.size(0) - # Step 2: global normalization — make matrix sum to 1 sum_Q = torch.sum(Q) if dist.is_initialized(): dist.all_reduce(sum_Q) Q /= sum_Q + 1e-8 - # Step 3: alternating row-column normalization for _ in range(n_iters): - # Row normalization: each prototype's total weight = 1/K sum_of_rows = torch.sum(Q, dim=1, keepdim=True) if dist.is_initialized(): dist.all_reduce(sum_of_rows) Q /= sum_of_rows + 1e-8 Q /= K - # Column normalization: each sample's total weight = 1/B Q /= torch.sum(Q, dim=0, keepdim=True) + 1e-8 Q /= B - # Step 4: scale back so columns sum to 1 (assignment) Q *= B return Q.t() @@ -121,9 +114,6 @@ def __init__( gumbel_temperature: float = 1.0, ) -> None: super().__init__(n_embed=n_embed, embed_dim=embed_dim) - # Sinkhorn drives `ids` (balanced assignment), Gumbel drives `emb` - # (nearest code); combining them makes the saved id and embedding - # diverge, so reject the combo (see the assert message). _is_gumbel = forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX assert not (use_sinkhorn and _is_gumbel), ( "use_sinkhorn=True is incompatible with forward_mode=GUMBEL_SOFTMAX: " @@ -135,7 +125,6 @@ def __init__( # (large, shifted) cost overflows to +Inf -> NaN assignments. if use_sinkhorn and sinkhorn_epsilon <= 0: raise ValueError(f"sinkhorn_epsilon must be > 0, got {sinkhorn_epsilon}") - # ``n_embed`` / ``embed_dim`` are owned by the QuantizeLayer base. self.forward_mode = forward_mode self.distance_type = distance_type self.use_sinkhorn = use_sinkhorn @@ -196,7 +185,6 @@ def _find_nearest_embedding(self, x: torch.Tensor) -> torch.Tensor: distances = (distances - mean) / std.add(1e-12) distances = distances - distances.min() - # Sinkhorn optimal-transport assignment Q = _sinkhorn( distances, n_iters=self.sinkhorn_iters, @@ -221,10 +209,8 @@ def quantize(self, x: torch.Tensor) -> QuantizeOutput: Returns: QuantizeOutput: named tuple of (embeddings, ids). """ - # Gumbel: grad-enabled distances feed the encoder; the hard sample drives - # both emb and ids, so the saved code matches the vector used. if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX: - logits = -self._compute_distances(x) # differentiable + logits = -self._compute_distances(x) weights = F.gumbel_softmax( logits, tau=self.gumbel_temperature, hard=True, dim=-1 ) @@ -232,13 +218,10 @@ def quantize(self, x: torch.Tensor) -> QuantizeOutput: ids = weights.argmax(dim=-1) return QuantizeOutput(embeddings=emb, ids=ids) - # STE / eval: nearest-neighbour assignment under no_grad, one codebook - # gather. Return the RAW codebook vector (grad-carrying to the codebook) - # so the residual quantizer's cumulative ``latents`` trains the codebook - # via the commitment loss. The encoder straight-through gradient is - # applied once on the aggregate in ``ResidualVectorQuantizer.forward``; a - # per-layer STE wrap here would detach the codebook from ``latents`` and - # leave it frozen at init. + # Return the RAW codebook vector (no per-layer STE wrap): the aggregate + # STE in ResidualVectorQuantizer.forward routes the encoder gradient, + # while a wrap here would detach the codebook from ``latents`` and freeze + # it at init. ids = self._find_nearest_embedding(x) return QuantizeOutput(embeddings=self.embedding(ids), ids=ids) From 6eca7c7a93660ce829a2bd10224319dd8b2d4fce Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 25 Jun 2026 06:58:37 +0000 Subject: [PATCH 127/129] [refactor] SID: ruff-format residual_vector_quantizer Run ruff format: stripping a trailing shape comment left the rescaled_embeddings expression short enough to fit on one line, so the formatter collapses the parenthesized form. Fixes the RunCodeStyleCI (ruff-format) failure on the prior commit. No behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/residual_vector_quantizer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tzrec/modules/sid/residual_vector_quantizer.py b/tzrec/modules/sid/residual_vector_quantizer.py index 2c590df27..560bea7bd 100644 --- a/tzrec/modules/sid/residual_vector_quantizer.py +++ b/tzrec/modules/sid/residual_vector_quantizer.py @@ -264,9 +264,7 @@ def _apply_rotation_trick( sum_projection = ( x_unsq @ normalized_sum.unsqueeze(2) @ normalized_sum.unsqueeze(1) ) - rescaled_embeddings = ( - x_unsq @ x_hat.unsqueeze(2) @ quant_hat.unsqueeze(1) - ) + rescaled_embeddings = x_unsq @ x_hat.unsqueeze(2) @ quant_hat.unsqueeze(1) return lambda_ * ( x_unsq - 2 * sum_projection + 2 * rescaled_embeddings ).squeeze(1) From d8e24d98f4c751a3a2efd12364ec2d71be80b96d Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 25 Jun 2026 07:27:24 +0000 Subject: [PATCH 128/129] [refactor] SID: SidReconLoss masked-mean; drop feature_group field (auto-detect) - SidReconLoss now owns the masked-mean reduction: forward(x_hat, x, mask) returns the scalar reconstruction loss, so all three SID loss modules take their operands/mask and return a scalar uniformly. Drop _masked_mean + div_no_nan from sid_model.py; its tests move to sid_recon_loss_test.py (sid_model_test.py held only those, so it is removed). - Drop the SidRqvae/SidRqkmeans `feature_group` proto field and _resolve_feature_group: the main input is just group_names()[0] (the first declared feature group), per the maintainer's "auto detect group_names()[0]". Single-group models need no field; the contrastive path names its paired groups in contrastive_config. Configs/tests drop the explicit feature_group. - Audit fixes in these files: stale "CLIP feature wiring" proto comment -> "contrastive"; rename a leftover `clip` test local to `contrastive`. Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/loss/sid_recon_loss.py | 57 ++++++++++++++----- tzrec/loss/sid_recon_loss_test.py | 39 +++++++++++-- tzrec/models/sid_model.py | 55 ++---------------- tzrec/models/sid_model_test.py | 39 ------------- tzrec/models/sid_rqvae_test.py | 13 ++--- tzrec/protos/models/sid_model.proto | 18 +----- .../configs/sid_rqvae_contrastive_mock.config | 6 +- 7 files changed, 92 insertions(+), 135 deletions(-) delete mode 100644 tzrec/models/sid_model_test.py diff --git a/tzrec/loss/sid_recon_loss.py b/tzrec/loss/sid_recon_loss.py index c6ba68fdc..02724f875 100644 --- a/tzrec/loss/sid_recon_loss.py +++ b/tzrec/loss/sid_recon_loss.py @@ -9,20 +9,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""SidReconLoss: per-row RQ-VAE reconstruction distance (input vs. decoder).""" +"""SidReconLoss: mask-aware RQ-VAE reconstruction loss (input vs. decoder).""" + +from typing import Optional import torch from torch.nn import functional as F from torch.nn.modules.loss import _Loss +from tzrec.modules.utils import div_no_nan + class SidReconLoss(_Loss): - """Per-row reconstruction distance for the configured ``recon_type``. + """Reconstruction loss for RQ-VAE: per-row distance reduced to a scalar. - ``forward(x_hat, x)`` returns the per-row distance ``(B,)`` - (``reduction="none"``); the model reduces it (a masked mean over the - reconstruction rows). Registered as a ``_loss_modules`` entry alongside the - commitment / contrastive losses. + ``forward(x_hat, x, mask)`` computes the per-row distance for the configured + ``recon_type`` and averages it over the masked-in rows (all rows if ``mask`` + is None; the mixed recon+contrastive path passes ``recon_mask`` to score the + reconstruction-only rows). Registered as a ``_loss_modules`` entry alongside + the commitment / contrastive losses and, like them, returns a scalar. Args: recon_type (str): the distance, ``"l2"`` (mse), ``"l1"`` or ``"cos"``. @@ -37,18 +42,42 @@ def __init__(self, recon_type: str = "l2") -> None: ) self.recon_type = recon_type - def forward(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - """Per-row reconstruction distance. + def _per_row(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """Per-row reconstruction distance, shape (B,).""" + if self.recon_type == "l1": + return F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) + if self.recon_type == "cos": + return 1 - F.cosine_similarity(x_hat, x, dim=-1) + return F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) # "l2" + + @staticmethod + def _masked_mean( + per_sample: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Mean over the masked-in rows (all rows if ``mask`` is None). + + The masked mean divides by the valid-row count (``div_no_nan`` keeps an + empty mask at 0). No data-dependent branching -> ``torch.compile``-friendly. + """ + if mask is None: + return per_sample.mean() + mask = mask.float() + return div_no_nan((per_sample * mask).sum(), mask.sum()) + + def forward( + self, + x_hat: torch.Tensor, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Mask-aware reconstruction loss. Args: x_hat (Tensor): reconstruction (decoder output), shape (B, D). x (Tensor): the input it reconstructs, shape (B, D). + mask (Tensor, optional): per-row bool; rows to include (all if None). Returns: - Tensor: per-row distance, shape (B,). + Tensor: scalar reconstruction loss. """ - if self.recon_type == "l1": - return F.l1_loss(x_hat, x, reduction="none").mean(dim=-1) - if self.recon_type == "cos": - return 1 - F.cosine_similarity(x_hat, x, dim=-1) - return F.mse_loss(x_hat, x, reduction="none").mean(dim=-1) # "l2" + return self._masked_mean(self._per_row(x_hat, x), mask) diff --git a/tzrec/loss/sid_recon_loss_test.py b/tzrec/loss/sid_recon_loss_test.py index eafa37f51..2d4a27e7a 100644 --- a/tzrec/loss/sid_recon_loss_test.py +++ b/tzrec/loss/sid_recon_loss_test.py @@ -18,27 +18,28 @@ class SidReconLossTest(unittest.TestCase): - """Tests for the per-row reconstruction-distance module.""" + """Tests for the reconstruction-loss module (per-row distance + reduction).""" def test_l2_is_per_row_mse(self) -> None: - d = SidReconLoss("l2")(torch.ones(3, 4), torch.zeros(3, 4)) + d = SidReconLoss("l2")._per_row(torch.ones(3, 4), torch.zeros(3, 4)) self.assertEqual(d.shape, (3,)) torch.testing.assert_close(d, torch.ones(3)) # mean of 1^2 over dim -1 def test_l1_is_per_row_mae(self) -> None: - d = SidReconLoss("l1")(torch.ones(2, 5), torch.zeros(2, 5)) + d = SidReconLoss("l1")._per_row(torch.ones(2, 5), torch.zeros(2, 5)) torch.testing.assert_close(d, torch.ones(2)) def test_cos_is_one_minus_cosine(self) -> None: x = torch.tensor([[1.0, 0.0]]) # identical vectors -> cosine 1 -> distance 0 - d = SidReconLoss("cos")(x, x.clone()) + d = SidReconLoss("cos")._per_row(x, x.clone()) torch.testing.assert_close(d, torch.zeros(1), atol=1e-6, rtol=0) @parameterized.expand([("l2",), ("l1",), ("cos",)]) - def test_each_type_finite_and_backprops(self, recon_type) -> None: + def test_each_type_scalar_and_backprops(self, recon_type) -> None: x_hat = torch.randn(4, 6, requires_grad=True) - loss = SidReconLoss(recon_type)(x_hat, torch.randn(4, 6)).mean() + loss = SidReconLoss(recon_type)(x_hat, torch.randn(4, 6)) + self.assertEqual(loss.shape, ()) # forward reduces to a scalar self.assertTrue(torch.isfinite(loss)) loss.backward() # grad must flow back to the (decoder) input self.assertIsNotNone(x_hat.grad) @@ -47,6 +48,32 @@ def test_unknown_type_raises(self) -> None: with self.assertRaisesRegex(ValueError, "recon_type"): SidReconLoss("nope") + # --- masked-mean reduction (forward's mask handling) --- + + def test_no_mask_is_plain_mean(self) -> None: + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + torch.testing.assert_close(SidReconLoss._masked_mean(x), x.mean()) + + def test_mask_averages_over_valid_rows_only(self) -> None: + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + mask = torch.tensor([True, False, True, False]) + torch.testing.assert_close( + SidReconLoss._masked_mean(x, mask), torch.tensor(2.0) + ) # (1+3)/2 + + def test_empty_mask_is_zero_not_nan(self) -> None: + out = SidReconLoss._masked_mean( + torch.tensor([1.0, 2.0, 3.0]), torch.zeros(3, dtype=torch.bool) + ) + self.assertEqual(out.item(), 0.0) + + def test_forward_applies_mask(self) -> None: + # l1 per-row of [[1,1],[2,2],[3,3],[4,4]] vs 0 is [1,2,3,4]; mask keeps 0,2. + x_hat = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]) + mask = torch.tensor([True, False, True, False]) + loss = SidReconLoss("l1")(x_hat, torch.zeros(4, 2), mask) + torch.testing.assert_close(loss, torch.tensor(2.0)) # (1+3)/2 + if __name__ == "__main__": unittest.main() diff --git a/tzrec/models/sid_model.py b/tzrec/models/sid_model.py index f5b98c132..96fb4d77d 100644 --- a/tzrec/models/sid_model.py +++ b/tzrec/models/sid_model.py @@ -25,30 +25,10 @@ from tzrec.metrics.unique_ratio import UniqueRatio from tzrec.models.model import BaseModel from tzrec.modules.embedding import EmbeddingGroup -from tzrec.modules.utils import div_no_nan from tzrec.protos.loss_pb2 import LossConfig from tzrec.protos.model_pb2 import ModelConfig -def _masked_mean( - per_sample: torch.Tensor, mask: Optional[torch.Tensor] = None -) -> torch.Tensor: - """Mean of a per-row loss over the masked-in rows (all rows if ``mask`` None). - - The mixed recon+contrastive path applies the reconstruction loss to recon rows - only; the masked mean divides by the valid-row count (``div_no_nan`` keeps an - empty mask at 0). No data-dependent branching → ``torch.compile``-friendly. - - Args: - per_sample (Tensor): per-row loss, shape (B,). - mask (Tensor, optional): per-row bool; rows to include. - """ - if mask is None: - return per_sample.mean() - mask = mask.float() - return div_no_nan((per_sample * mask).sum(), mask.sum()) - - class BaseSidModel(BaseModel): """Shared base for semantic-ID (SID) generation models. @@ -104,7 +84,7 @@ def __init__( self._n_layers = len(self._n_embed_list) self.init_input() - self._feature_group = self._resolve_feature_group() + self._feature_group = self.embedding_group.group_names()[0] self._input_dim = self.embedding_group.group_total_dim(self._feature_group) if self._input_dim < 1: raise ValueError( @@ -120,29 +100,6 @@ def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]: """Build grouped input features: ``{group_name: (B, group_total_dim)}``.""" return self.embedding_group(batch) - def _resolve_feature_group(self) -> str: - """Resolve the main input feature group name. - - Uses ``feature_group`` when set; otherwise, when exactly one group is - declared, that sole group (DLRM-style auto-detect); otherwise fails as - ambiguous. The resolved name must exist in the model's feature groups. - """ - groups = self.embedding_group.group_names() - if self._model_config.HasField("feature_group"): - name = self._model_config.feature_group - if name not in groups: - raise ValueError( - f"feature_group {name!r} is not in model_config.feature_groups " - f"{groups}" - ) - return name - if len(groups) == 1: - return groups[0] - raise ValueError( - "feature_group must be set when multiple feature_groups are declared, " - f"got groups {groups}" - ) - def init_loss(self) -> None: """Initialize SID loss modules from ``ModelConfig.losses``. @@ -203,12 +160,12 @@ def _sid_loss_impl( """Compute one ``sid_loss`` term from ``predictions``.""" loss_type = loss_cfg.WhichOneof("sid_loss") if loss_type == "recon_loss": - per_sample = self._loss_modules["recon_loss"]( - predictions["x_hat"], predictions["recon_target"] + loss = self._loss_modules["recon_loss"]( + predictions["x_hat"], + predictions["recon_target"], + predictions.get("recon_mask"), ) - return { - "recon_loss": _masked_mean(per_sample, predictions.get("recon_mask")) - } + return {"recon_loss": loss} elif loss_type == "commitment_loss": loss = self._loss_modules["commitment_loss"]( predictions["encoder_out"], predictions["latents"] diff --git a/tzrec/models/sid_model_test.py b/tzrec/models/sid_model_test.py deleted file mode 100644 index f4eadcf3d..000000000 --- a/tzrec/models/sid_model_test.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2026, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch - -from tzrec.models.sid_model import _masked_mean - - -class MaskedMeanTest(unittest.TestCase): - """Tests for the shared ``BaseSidModel`` masked-mean reduction.""" - - def test_no_mask_is_plain_mean(self) -> None: - x = torch.tensor([1.0, 2.0, 3.0, 4.0]) - torch.testing.assert_close(_masked_mean(x), x.mean()) - - def test_mask_averages_over_valid_rows_only(self) -> None: - x = torch.tensor([1.0, 2.0, 3.0, 4.0]) - mask = torch.tensor([True, False, True, False]) - torch.testing.assert_close(_masked_mean(x, mask), torch.tensor(2.0)) # (1+3)/2 - - def test_empty_mask_is_zero_not_nan(self) -> None: - out = _masked_mean( - torch.tensor([1.0, 2.0, 3.0]), torch.zeros(3, dtype=torch.bool) - ) - self.assertEqual(out.item(), 0.0) - - -if __name__ == "__main__": - unittest.main() diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py index de0ac8757..292fe552e 100644 --- a/tzrec/models/sid_rqvae_test.py +++ b/tzrec/models/sid_rqvae_test.py @@ -117,8 +117,6 @@ def _create_model( ) losses = [_recon_loss_cfg(recon), _commitment_cfg()] if use_contrastive: - # Multiple feature groups -> the main group must be named explicitly. - sid_rqvae_cfg.feature_group = "deep" sid_rqvae_cfg.contrastive_config.pair_feature_group = "pair" sid_rqvae_cfg.contrastive_config.pair_flag_feature_group = "pair_flag" losses.append(_contrastive_cfg()) @@ -302,7 +300,6 @@ def test_pair_feature_group_dim_mismatch_raises(self) -> None: 32, use_contrastive=True, pair_emb_dim=16 ) cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) - cfg.feature_group = "deep" cfg.contrastive_config.pair_feature_group = "pair" cfg.contrastive_config.pair_flag_feature_group = "pair_flag" model_config = model_pb2.ModelConfig( @@ -317,7 +314,6 @@ def test_pair_flag_group_must_be_dim_1(self) -> None: 32, use_contrastive=True, flag_dim=3 ) cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) - cfg.feature_group = "deep" cfg.contrastive_config.pair_feature_group = "pair" cfg.contrastive_config.pair_flag_feature_group = "pair_flag" model_config = model_pb2.ModelConfig( @@ -330,7 +326,6 @@ def test_contrastive_group_missing_raises(self) -> None: """A typo'd contrastive group name fails fast at init, not on forward.""" features, feature_groups = _features_and_groups(32, use_contrastive=True) cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) - cfg.feature_group = "deep" cfg.contrastive_config.pair_feature_group = "pair" cfg.contrastive_config.pair_flag_feature_group = "pair_flagTYPO" model_config = model_pb2.ModelConfig( @@ -443,8 +438,12 @@ def test_logit_scale_clamped_prevents_overflow(self) -> None: model.train() model.init_loss() # The temperatures live on the contrastive module that owns the clamp. - clip = model._loss_modules["contrastive_loss"] - scales = (clip.logit_scale_self, clip.logit_scale_cl, clip.logit_scale_ori) + contrastive = model._loss_modules["contrastive_loss"] + scales = ( + contrastive.logit_scale_self, + contrastive.logit_scale_cl, + contrastive.logit_scale_ori, + ) with torch.no_grad(): for p in scales: p.fill_(100.0) diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index 5f2a373ff..23efdce93 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -85,15 +85,8 @@ message SidRqvae { optional ContrastiveConfig contrastive_config = 16; // Reconstruction, commitment and (optional) contrastive losses are configured - // via ModelConfig.losses (the LossConfig ``sid_loss`` oneof); only the CLIP - // feature wiring above lives on this message. - - // Name of the main input FEATURE GROUP (built by the model's EmbeddingGroup - // from ModelConfig.feature_groups). May hold one or many content/side-info - // features; their concatenated dim is the encoder input_dim. Optional: when - // unset and exactly one feature group is declared, that group is used; with - // multiple groups it must be set explicitly. - optional string feature_group = 40; + // via ModelConfig.losses (the LossConfig ``sid_loss`` oneof); only the + // contrastive feature wiring above lives on this message. } message SidRqkmeans { @@ -114,11 +107,4 @@ message SidRqkmeans { // what FAISS subsamples to internally (default 256), so no training points // are wasted. optional uint32 train_sample_size = 6 [default = 0]; - - // Name of the main input FEATURE GROUP (built by the model's EmbeddingGroup - // from ModelConfig.feature_groups). May hold one or many content/side-info - // features; their concatenated dim is the K-Means dimension. Optional: when - // unset and exactly one feature group is declared, that group is used; with - // multiple groups it must be set explicitly. - optional string feature_group = 40; } diff --git a/tzrec/tests/configs/sid_rqvae_contrastive_mock.config b/tzrec/tests/configs/sid_rqvae_contrastive_mock.config index 7c4154afd..0fbe9a4cc 100644 --- a/tzrec/tests/configs/sid_rqvae_contrastive_mock.config +++ b/tzrec/tests/configs/sid_rqvae_contrastive_mock.config @@ -72,10 +72,8 @@ model_config { codebook: 16 forward_mode: "ste" kmeans_init: false - # Multiple feature groups -> name the main group explicitly. - feature_group: "deep" - # "pair" shares the encoder with "deep", so it must match its dim; - # "pair_flag" flags rows (>0.5 = pair). Objective: contrastive_loss below. + # "pair" shares the main encoder so must match its dim; "pair_flag" flags + # rows (>0.5 = pair). contrastive_config { pair_feature_group: "pair" pair_flag_feature_group: "pair_flag" From 60eac78182d839bb822edba31e23b47e6bb1cc2f Mon Sep 17 00:00:00 2001 From: shuqi <597191244@qq.com> Date: Thu, 25 Jun 2026 07:27:24 +0000 Subject: [PATCH 129/129] [test] SID: strengthen reservoir phase-2 assertion test_phase2_replacement asserted only ``(idx >= cap).any()``, which is near-tautological (it catches replacement being disabled outright but misses a badly-low accept probability). Require the phase-2 survivor count to exceed cap // 2 (the expected count is ~= cap). Co-Authored-By: Claude Opus 4.8 (1M context) --- tzrec/modules/sid/kmeans_quantize_test.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tzrec/modules/sid/kmeans_quantize_test.py b/tzrec/modules/sid/kmeans_quantize_test.py index 59ec45a01..008b2b08b 100644 --- a/tzrec/modules/sid/kmeans_quantize_test.py +++ b/tzrec/modules/sid/kmeans_quantize_test.py @@ -151,9 +151,13 @@ def test_phase2_replacement(self) -> None: ) # All indices are valid stream positions. self.assertTrue((idx >= 0).all() and (idx < total).all()) - # Phase-2 replacement happened: at least one slot holds a row added - # after the reservoir filled (index >= cap). - self.assertTrue((idx >= cap).any(), "no Phase-2 replacement occurred") + # Phase-2 replacement dominates the final sample: with a correct accept + # probability the expected post-fill survivor count is + # cap*(total-cap)/total ~= cap, so require well over half. A near-empty + # phase-2 count means the accept rate is broken (``.any()`` would only + # catch replacement being disabled outright). + n_phase2 = (idx >= cap).sum().item() + self.assertGreater(n_phase2, cap // 2, f"too few Phase-2 rows: {n_phase2}") def test_reset(self) -> None: """reset() drops the buffer and counters."""