class SidCommitmentLoss(_Loss):
    """Commitment loss between the encoder output and the quantized vectors.

    Operates on a residual quantizer's per-layer cumulative quantized vectors
    (the ``latents`` of :class:`~tzrec.modules.sid.types.ResidualQuantizerOutput`).
    Both VQ-VAE directions are computed over all residual layers and combined
    as ``w1 * loss1 + w2 * loss2``:

    - ``loss1`` = encoder-toward-quant: ``latents`` is detached, so the gradient
      flows into the encoder only; weighted by ``w1``.
    - ``loss2`` = quant-toward-encoder: ``encoder_out`` is detached, so the
      gradient flows into ``latents`` (the codebook side) only; weighted by ``w2``.

    Args:
        latent_weight (Sequence[float]): commitment weights ``[w1, w2]``.
            Default: ``(1.0, 0.5)``.
        commitment_type (str): distance, ``"l2"``, ``"l1"`` or ``"cos"``.
            Default: ``"l2"``.

    Raises:
        ValueError: if ``latent_weight`` does not hold exactly two values.
        AssertionError: if ``commitment_type`` is not a supported distance.
    """

    def __init__(
        self,
        latent_weight: Sequence[float] = (1.0, 0.5),
        commitment_type: str = "l2",
    ) -> None:
        super().__init__()
        if len(latent_weight) != 2:
            raise ValueError(
                f"latent_weight must have exactly 2 values [w1, w2], got "
                f"{list(latent_weight)}"
            )
        # Raised explicitly instead of via an ``assert`` statement: ``python -O``
        # strips asserts, which would let an unknown type through and make
        # ``forward`` silently fall back to the "l2" branch. The exception type
        # stays AssertionError so existing callers keep catching the same error.
        if commitment_type not in ("l2", "l1", "cos"):
            raise AssertionError(
                f"commitment_type must be 'l2', 'l1' or 'cos', got {commitment_type!r}"
            )
        self.commitment_w1, self.commitment_w2 = latent_weight
        self.commitment_type = commitment_type

    def forward(self, encoder_out: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        """Compute the commitment loss.

        Args:
            encoder_out (Tensor): encoder output (the quantizer input), shape
                (B, D).
            latents (Tensor): per-layer cumulative quantized vectors, shape
                (B, n_layers, D).

        Returns:
            Tensor: scalar commitment loss. "l2" / "l1" average over every
                element (batch, layers and features); "cos" averages the
                per-vector cosine distance over batch and layers.
        """
        x = encoder_out.unsqueeze(1)  # (B, 1, D) -> broadcasts over layers
        if self.commitment_type == "cos":
            loss1 = (1 - F.cosine_similarity(x, latents.detach(), dim=-1)).mean()
            loss2 = (1 - F.cosine_similarity(x.detach(), latents, dim=-1)).mean()
        elif self.commitment_type == "l1":
            loss1 = (x - latents.detach()).abs().mean()
            loss2 = (x.detach() - latents).abs().mean()
        else:  # "l2" (the constructor rejects every other value)
            loss1 = (x - latents.detach()).pow(2.0).mean()
            loss2 = (x.detach() - latents).pow(2.0).mean()
        return self.commitment_w1 * loss1 + self.commitment_w2 * loss2
class SidCommitmentLossTest(unittest.TestCase):
    """Construction-time validation and forward/backward of SidCommitmentLoss."""

    @parameterized.expand([("l2",), ("l1",), ("cos",)])
    def test_branch_runs_and_backprops(self, commitment_type) -> None:
        """Every distance type gives a finite scalar and grads on both inputs."""
        torch.manual_seed(0)
        batch, n_layers, dim = 4, 3, 8
        criterion = SidCommitmentLoss(
            commitment_type=commitment_type, latent_weight=(1.0, 0.5)
        )
        enc = torch.randn(batch, dim, requires_grad=True)
        quantized = torch.randn(batch, n_layers, dim, requires_grad=True)

        value = criterion(enc, quantized)
        self.assertEqual(value.shape, ())
        self.assertTrue(torch.isfinite(value))

        value.backward()
        # The encoder-toward-quant term reaches ``enc``; the
        # quant-toward-encoder term reaches ``quantized``.
        for grad in (enc.grad, quantized.grad):
            self.assertIsNotNone(grad)
        self.assertTrue(torch.isfinite(enc.grad).all())

    def test_latent_weight_wrong_length_raises(self) -> None:
        """Anything other than exactly two weights is rejected."""
        with self.assertRaisesRegex(ValueError, "latent_weight"):
            SidCommitmentLoss(latent_weight=[1.0])
        with self.assertRaisesRegex(ValueError, "latent_weight"):
            SidCommitmentLoss(latent_weight=[1.0, 0.5, 0.25])

    def test_invalid_commitment_type_raises(self) -> None:
        """A distance name outside l2/l1/cos fails at construction."""
        with self.assertRaisesRegex(AssertionError, "commitment_type"):
            SidCommitmentLoss(commitment_type="bogus")

    def test_weights_scale_the_two_directions(self) -> None:
        """Non-zero weights give a positive loss; all-zero weights give zero."""
        torch.manual_seed(0)
        enc = torch.randn(4, 8)
        quantized = torch.randn(4, 3, 8)
        weighted = SidCommitmentLoss(commitment_type="l2", latent_weight=(1.0, 0.5))
        silenced = SidCommitmentLoss(commitment_type="l2", latent_weight=(0.0, 0.0))
        weighted_value = weighted(enc, quantized).item()
        self.assertGreater(weighted_value, 0.0)
        silenced_value = silenced(enc, quantized).item()
        self.assertEqual(silenced_value, 0.0)


if __name__ == "__main__":
    unittest.main()
# CLIP temperature init (reference CLIP: log(1 / 0.07)) and the cap applied
# before ``exp`` (reference CLIP clamps to ln(100)): an unbounded temperature
# would overflow to +Inf -> NaN grad -> corrupt param.
_LOGIT_SCALE_INIT = math.log(1 / 0.07)
_LOGIT_SCALE_MAX = math.log(100)


class SidContrastiveLoss(_Loss):
    """Masked InfoNCE pair-contrastive loss for mixed (paired + non-paired) batches.

    Modality-agnostic: aligns two reconstructed "views" (``embed_a`` / ``embed_b``)
    against each other and against their originals (``embed_a_ori`` /
    ``embed_b_ori``) with three symmetric InfoNCE terms (self/ori/cl). In a mixed
    batch, non-pair rows (``pair_mask=False``) must not contribute and must not
    serve as negatives; row/column masks achieve this without data-dependent
    branching (``torch.compile``-friendly).

    ``forward`` takes the four ``(B, dim)`` view embeddings plus the ``(B,)`` pair
    mask and returns the scalar mean of the three contrastive terms. The three
    temperatures (self/ori/cl) are learnable parameters owned by this module;
    ``forward`` clamps (to <= ln(100)) and ``exp``s them.

    Attributes kept between calls (plain attributes, not buffers):
        labels: cached positive-column index of every local row.
        last_local_batch_size: local batch size ``labels`` was built for.
        _rank: process rank ``labels`` was built for.
    """

    def __init__(self) -> None:
        super().__init__()
        self.labels: Optional[torch.Tensor] = None
        self.last_local_batch_size: Optional[int] = None
        # Best-effort initial value only; ``forward`` re-reads the rank because
        # the module may be built before the process group is initialised.
        self._rank = dist.get_rank() if dist.is_initialized() else 0
        # Learnable contrastive temperatures, one per group (self / ori / cl);
        # registered here so the loss module is self-contained.
        self.logit_scale_self = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT)
        self.logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT)
        self.logit_scale_ori = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT)

    @staticmethod
    def _scaled(logit_scale: torch.Tensor) -> torch.Tensor:
        """Return ``exp`` of the log-temperature, capped at ``exp(ln 100)``."""
        # Clamp before exp so a large temperature can't overflow to +Inf -> NaN.
        return logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp()

    @staticmethod
    def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
        """All-gather tensors across workers with gradient support.

        Single-process: returns the inputs unchanged. Multi-process: uses the
        built-in differentiable ``torch.distributed.nn.functional.all_gather``,
        so no custom ``autograd.Function`` is needed.

        Args:
            tensors (List[Tensor]): list of tensors to gather.

        Returns:
            List[Tensor]: gathered tensors, each (world_size * B, ...).
        """
        if not dist.is_initialized() or dist.get_world_size() == 1:
            return tensors
        gathered: List[torch.Tensor] = []
        for tensor in tensors:
            tensor_all = dist_nn.all_gather(tensor)  # differentiable, per rank
            gathered.append(torch.cat(tensor_all, dim=0))
        return gathered

    @staticmethod
    def _gather_bool_mask(mask: torch.Tensor) -> torch.Tensor:
        """All-gather bool mask across distributed workers.

        Single-process: returns ``mask`` unchanged. Multi-process: concatenates
        every rank's mask along dim 0 (no gradient is needed for a mask, so the
        plain ``dist.all_gather`` collective is used).
        """
        if not dist.is_initialized() or dist.get_world_size() == 1:
            return mask
        mask_list = [torch.zeros_like(mask) for _ in range(dist.get_world_size())]
        dist.all_gather(mask_list, mask)
        return torch.cat(mask_list, dim=0)

    def _masked_cross_entropy(
        self,
        logits_a: torch.Tensor,
        logits_b: torch.Tensor,
        safe_labels: torch.Tensor,
        pair_mask_f: torch.Tensor,
        n_valid: torch.Tensor,
    ) -> torch.Tensor:
        """Masked cross-entropy on column-masked logits, row-masked average.

        Args:
            logits_a: (B, B_global) column-masked logits (view-a branch).
            logits_b: (B, B_global) column-masked logits (view-b branch).
            safe_labels: (B,) labels with non-pair rows fallback to a safe col.
            pair_mask_f: (B,) float pair mask (1.0 = pair row).
            n_valid: scalar pair-row count, clamped to >= 1.

        Returns:
            Tensor: scalar; the two branches' per-row CE summed over pair rows
                and divided by ``2 * n_valid``.
        """
        ce_a = F.cross_entropy(logits_a, safe_labels, reduction="none")
        ce_b = F.cross_entropy(logits_b, safe_labels, reduction="none")
        # Backstop against a non-finite upstream logit (e.g. overflowed scale).
        ce_a = torch.nan_to_num(ce_a, nan=0.0)
        ce_b = torch.nan_to_num(ce_b, nan=0.0)

        return ((ce_a + ce_b) * pair_mask_f).sum() / (2 * n_valid)

    def forward(
        self,
        embed_a: torch.Tensor,
        embed_b: torch.Tensor,
        embed_a_ori: torch.Tensor,
        embed_b_ori: torch.Tensor,
        pair_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the masked pair-contrastive loss.

        Only ``embed_a`` / ``embed_b`` are L2-normalized here; the ``*_ori``
        operands are used as passed in.

        Args:
            embed_a: (B, dim) reconstructed (decoder) output of view a.
            embed_b: (B, dim) reconstructed (decoder) output of view b.
            embed_a_ori: (B, dim) original embedding of view a.
            embed_b_ori: (B, dim) original embedding of view b.
            pair_mask: (B,) bool, True = contrastive-pair sample.

        Returns:
            Tensor: scalar mean of the three contrastive terms (self/ori/cl).
        """
        logit_scale_self = self._scaled(self.logit_scale_self)
        logit_scale_ori = self._scaled(self.logit_scale_ori)
        logit_scale_cl = self._scaled(self.logit_scale_cl)

        local_batch_size = embed_a.size(0)
        device = embed_a.device

        # Labels carry the cross-rank offset and live on the input device, so
        # refresh them whenever batch size, rank or device changed. The rank is
        # read here (not only in __init__) so a module constructed before
        # ``init_process_group`` does not keep pointing at rank 0's columns.
        rank = dist.get_rank() if dist.is_initialized() else 0
        labels_stale = (
            self.labels is None
            or local_batch_size != self.last_local_batch_size
            or rank != self._rank
            or self.labels.device != device
        )
        if labels_stale:
            self._rank = rank
            self.labels = local_batch_size * rank + torch.arange(
                local_batch_size, device=device
            )
            self.last_local_batch_size = local_batch_size

        embed_a = F.normalize(embed_a, dim=-1, p=2)
        embed_b = F.normalize(embed_b, dim=-1, p=2)

        # One batched all-gather for all four operands (gradient-preserving).
        embed_a_all, embed_b_all, embed_a_all_ori, embed_b_all_ori = (
            self._all_gather_with_grad([embed_a, embed_b, embed_a_ori, embed_b_ori])
        )

        pair_mask_all = self._gather_bool_mask(pair_mask)
        col_mask = (~pair_mask_all).unsqueeze(0)  # (1, B_global)

        labels = self.labels
        # Non-pair rows are pointed at the positive column of the first local
        # pair row. Index ``labels`` so the fallback is a GLOBAL column (with
        # the rank offset), not a local row index that lands in rank 0's block.
        first_pair = pair_mask.long().argmax()
        fallback = labels[first_pair]
        safe_labels = torch.where(pair_mask, labels, fallback.expand_as(labels))
        pair_mask_f = pair_mask.float()
        n_valid = pair_mask_f.sum().clamp(min=1)

        # Three symmetric contrastive groups, each (scale, a-target, b-target):
        #   self: recon-a vs recon-b   (vs the other recon view)
        #   ori:  recon vs the counterpart original
        #   cl:   recon vs its own-view original
        groups = (
            (logit_scale_self, embed_b_all, embed_a_all),
            (logit_scale_ori, embed_b_all_ori, embed_a_all_ori),
            (logit_scale_cl, embed_a_all_ori, embed_b_all_ori),
        )
        loss = embed_a.new_zeros(())
        for scale, a_target, b_target in groups:
            logits_a = scale * embed_a @ a_target.t()
            logits_b = scale * embed_b @ b_target.t()
            # Fill masked columns with the LOGITS dtype's most negative finite
            # value: below any real logit (masks like -inf) but finite, so an
            # all-non-pair row yields a finite CE/grad instead of 0*NaN. Derive
            # it from the logits dtype, not the embeddings': under autocast the
            # matmul casts to bf16/fp16 and finfo(embed.dtype=fp32).min would
            # overflow masked_fill on the lower-precision logits.
            neg_fill = torch.finfo(logits_a.dtype).min
            logits_a = logits_a.masked_fill(col_mask, neg_fill)
            logits_b = logits_b.masked_fill(col_mask, neg_fill)
            loss = loss + self._masked_cross_entropy(
                logits_a, logits_b, safe_labels, pair_mask_f, n_valid
            )
        return loss / 3
class AllGatherWithGradTest(unittest.TestCase):
    """Single-process behaviour of ``SidContrastiveLoss._all_gather_with_grad``."""

    def test_single_process_identity(self) -> None:
        """Without a process group the very same tensor objects come back."""
        first, second = torch.randn(3, 4), torch.randn(3, 4)
        gathered = SidContrastiveLoss._all_gather_with_grad([first, second])
        self.assertIs(gathered[0], first)
        self.assertIs(gathered[1], second)


class SidContrastiveLossTest(unittest.TestCase):
    """Single-process tests for the masked pair-contrastive loss."""

    def _features(self, B: int, D: int) -> dict:
        """Seeded inputs: grad-enabled recon views plus plain originals."""
        torch.manual_seed(0)
        recon_a = torch.randn(B, D, requires_grad=True)
        recon_b = torch.randn(B, D, requires_grad=True)
        orig_a = torch.randn(B, D)
        orig_b = torch.randn(B, D)
        return dict(
            embed_a=recon_a,
            embed_b=recon_b,
            embed_a_ori=orig_a,
            embed_b_ori=orig_b,
        )

    def test_forward_all_pairs_finite(self) -> None:
        """An all-pair batch gives a finite, strictly positive loss."""
        criterion = SidContrastiveLoss()
        inputs = self._features(6, 8)
        every_row = torch.ones(6, dtype=torch.bool)
        value = criterion(pair_mask=every_row, **inputs)
        self.assertTrue(torch.isfinite(value))
        self.assertGreater(value.item(), 0.0)

    def test_all_recon_mask_zero_loss(self) -> None:
        """With no pair rows the masked average is exactly zero (and finite)."""
        criterion = SidContrastiveLoss()
        inputs = self._features(6, 8)
        no_row = torch.zeros(6, dtype=torch.bool)  # no pair rows
        value = criterion(pair_mask=no_row, **inputs)
        self.assertTrue(torch.isfinite(value))
        self.assertAlmostEqual(value.item(), 0.0, places=6)

    def test_all_recon_mask_finite_gradient(self) -> None:
        """An all-recon batch must back-propagate a finite, zero gradient.

        Regression: with a float("-inf") column fill an all-recon batch produced
        a NaN gradient (0 * NaN) that survived the row mask. The finite fill
        keeps the backward finite (and zero, since no pair row contributes).
        """
        criterion = SidContrastiveLoss()
        inputs = self._features(6, 8)
        no_row = torch.zeros(6, dtype=torch.bool)
        value = criterion(pair_mask=no_row, **inputs)
        value.backward()
        grad_a = inputs["embed_a"].grad
        self.assertIsNotNone(grad_a)
        self.assertTrue(torch.isfinite(grad_a).all())
        self.assertAlmostEqual(grad_a.abs().sum().item(), 0.0, places=6)

    def test_backward_flows_to_embeddings(self) -> None:
        """An all-pair batch sends a finite gradient to ``embed_a``."""
        criterion = SidContrastiveLoss()
        inputs = self._features(6, 8)
        every_row = torch.ones(6, dtype=torch.bool)
        value = criterion(pair_mask=every_row, **inputs)
        value.backward()
        grad_a = inputs["embed_a"].grad
        self.assertIsNotNone(grad_a)
        self.assertTrue(torch.isfinite(grad_a).all())

    def test_recon_columns_excluded_from_negatives(self) -> None:
        """A recon row's embedding must not affect a pair row's loss.

        Recon rows are dropped as queries (row mask) AND their columns are
        masked out of the negatives (col_mask). Perturbing the recon rows of
        EVERY column operand — ``embed_b`` (the self group) and both
        ``*_ori`` operands (the ori/cl groups) — must leave the pair rows' loss
        unchanged; a dropped or inverted ``col_mask`` on any group would fail.
        Distinct ``embed_a_ori`` / ``embed_b_ori`` so the ori/cl masking
        is actually exercised (not hidden by a shared tensor).
        """
        torch.manual_seed(0)
        B, D = 4, 8
        view_a = torch.randn(B, D)
        pair_rows = torch.tensor([True, True, False, False])  # rows 2,3 are recon

        def as_kwargs(
            view_b: torch.Tensor, view_b_ori: torch.Tensor, view_a_ori: torch.Tensor
        ) -> dict:
            return dict(
                embed_a=view_a,
                embed_b=view_b,
                embed_a_ori=view_a_ori,
                embed_b_ori=view_b_ori,
            )

        view_b = torch.randn(B, D)
        view_b_ori = torch.randn(B, D)
        view_a_ori = torch.randn(B, D)
        criterion = SidContrastiveLoss()
        criterion.eval()
        before = criterion(
            pair_mask=pair_rows, **as_kwargs(view_b, view_b_ori, view_a_ori)
        )
        # Perturb ONLY the recon rows of every column operand that feeds negatives.
        perturbed = [view_b.clone(), view_b_ori.clone(), view_a_ori.clone()]
        for tensor in perturbed:
            tensor[2:] = torch.randn(2, D)
        after = criterion(pair_mask=pair_rows, **as_kwargs(*perturbed))
        torch.testing.assert_close(before, after)

    def test_autocast_bf16_does_not_overflow_masked_fill(self) -> None:
        """The column fill must be derived from the logits dtype.

        Regression: under autocast the matmul emits bf16, so finfo(fp32).min
        would raise "cannot be converted to BFloat16 without overflow" in
        masked_fill. A mixed pair/non-pair mask exercises the fill.
        """
        criterion = SidContrastiveLoss()
        inputs = self._features(6, 8)
        half_pairs = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool)
        with torch.autocast("cpu", dtype=torch.bfloat16):
            value = criterion(pair_mask=half_pairs, **inputs)
        self.assertTrue(torch.isfinite(value))

    def test_mask_holds_under_large_scale(self) -> None:
        """Masking and finiteness survive huge temperatures and raw originals.

        The column fill is finfo.min (below any real logit) rather than a
        hardcoded -1e4, so masking holds even when the temperature is large and
        the *_ori operands are un-normalized (real logits can dwarf 1e4). The
        loss's internal clamp caps exp() at <= 100; loss/grad must stay finite.
        """
        criterion = SidContrastiveLoss()
        with torch.no_grad():
            for name in ("logit_scale_ori", "logit_scale_self", "logit_scale_cl"):
                getattr(criterion, name).fill_(3000.0)
        criterion.eval()
        inputs = self._features(6, 8)
        for key in ("embed_a_ori", "embed_b_ori"):
            inputs[key] = inputs[key] * 50
        half_pairs = torch.tensor([1, 1, 1, 0, 0, 0], dtype=torch.bool)
        value = criterion(pair_mask=half_pairs, **inputs)
        self.assertTrue(torch.isfinite(value))
        criterion.train()
        inputs["embed_a"].grad = None
        criterion(pair_mask=half_pairs, **inputs).backward()
        self.assertTrue(torch.isfinite(inputs["embed_a"].grad).all())
# --- Multi-process tests for the contrastive distributed all-gather path. ---
# Validates ``_all_gather_with_grad`` (built on the differentiable
# ``torch.distributed.nn.functional.all_gather``) and ``SidContrastiveLoss`` across
# ranks. Uses NCCL on GPU when >=2 devices are available (the production path the
# reviewer cared about), else falls back to gloo/CPU, so it runs on a multi-GPU
# box and in CPU CI alike.

WORLD_SIZE = 2


def _init(rank: int, world_size: int, port: int) -> torch.device:
    """Join the process group for this worker and return the device to use."""
    os.environ.update(
        {
            "RANK": str(rank),
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(port),
        }
    )
    # One GPU per rank is required for NCCL; otherwise fall back to gloo on CPU.
    if not (torch.cuda.is_available() and torch.cuda.device_count() >= world_size):
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
        return torch.device("cpu")
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return torch.device(f"cuda:{rank}")


def _all_gather_worker(rank: int, world_size: int, port: int) -> None:
    """Check forward values and backward gradients of the gradient all-gather."""
    device = _init(rank, world_size, port)
    rows = 2
    # Each rank holds a distinct, rank-identifying tensor.
    local = torch.full((rows, 3), float(rank + 1), device=device, requires_grad=True)
    (gathered,) = SidContrastiveLoss._all_gather_with_grad([local])

    # Forward: gathered is (world_size*2, 3); rank r contributes rows
    # [2r : 2r+2] all equal to (r+1).
    assert gathered.shape == (world_size * 2, 3), gathered.shape
    for r in range(world_size):
        block = gathered[2 * r : 2 * r + 2]
        want = torch.full_like(block, float(r + 1))
        assert torch.allclose(block, want), (
            f"rank{rank}: gathered block {r} wrong: {block}"
        )

    # Backward: identical scalar loss on every rank -> grad to every gathered
    # element is 1; the differentiable all_gather sum-reduces across ranks,
    # so the local input grad is world_size * ones.
    gathered.sum().backward()
    assert local.grad is not None, f"rank{rank}: no grad"
    assert torch.isfinite(local.grad).all(), f"rank{rank}: non-finite grad"
    expected = torch.full_like(local, float(world_size))
    assert torch.allclose(local.grad, expected), (
        f"rank{rank}: grad {local.grad} != {expected}"
    )
    dist.destroy_process_group()


def _contrastive_worker(rank: int, world_size: int, port: int) -> None:
    """Run one forward/backward of the loss on this rank and check finiteness."""
    device = _init(rank, world_size, port)
    torch.manual_seed(1234 + rank)
    B, D = 4, 8
    recon_a = torch.randn(B, D, device=device, requires_grad=True)
    recon_b = torch.randn(B, D, device=device, requires_grad=True)
    orig_a = torch.randn(B, D, device=device)
    orig_b = torch.randn(B, D, device=device)
    every_row = torch.ones(B, dtype=torch.bool, device=device)

    criterion = SidContrastiveLoss().to(device)
    value = criterion(
        embed_a=recon_a,
        embed_b=recon_b,
        embed_a_ori=orig_a,
        embed_b_ori=orig_b,
        pair_mask=every_row,
    )
    assert torch.isfinite(value).all(), f"rank{rank}: non-finite loss"
    assert value.item() > 0.0, f"rank{rank}: loss not positive"

    value.backward()
    grad_a = recon_a.grad
    assert grad_a is not None and torch.isfinite(grad_a).all(), f"rank{rank}: bad grad"
    dist.destroy_process_group()


def _run(target) -> None:
    """Spawn ``WORLD_SIZE`` workers running ``target`` and fail if any one fails."""
    port = misc_util.get_free_port()
    spawn_ctx = mp.get_context("spawn")
    workers = []
    for rank in range(WORLD_SIZE):
        worker = spawn_ctx.Process(target=target, args=(rank, WORLD_SIZE, port))
        worker.start()
        workers.append(worker)
    # Join in rank order; the first non-zero exit code aborts the test.
    for i, worker in enumerate(workers):
        worker.join()
        if worker.exitcode != 0:
            raise RuntimeError(f"worker-{i} failed (exitcode={worker.exitcode}).")


class SidContrastiveDistTest(unittest.TestCase):
    """2-rank tests for the contrastive distributed collectives."""

    def test_all_gather_with_grad(self) -> None:
        """The gradient-preserving all-gather is correct across two ranks."""
        _run(_all_gather_worker)

    def test_masked_contrastive_loss(self) -> None:
        """The full loss runs forward and backward across two ranks."""
        _run(_contrastive_worker)


if __name__ == "__main__":
    unittest.main()
class SidReconLoss(_Loss):
    """Reconstruction loss for RQ-VAE: per-row distance reduced to a scalar.

    ``forward(x_hat, x, mask)`` computes the per-row distance for the configured
    ``recon_type`` and averages it over the masked-in rows (all rows if ``mask``
    is None; the mixed recon+contrastive path passes ``recon_mask`` to score the
    reconstruction-only rows). Registered as a ``_loss_modules`` entry alongside
    the commitment / contrastive losses and, like them, returns a scalar.

    Args:
        recon_type (str): the distance, ``"l2"`` (mse), ``"l1"`` or ``"cos"``.
            Default: ``"l2"``.

    Raises:
        ValueError: if ``recon_type`` is not one of the supported distances.
    """

    def __init__(self, recon_type: str = "l2") -> None:
        super().__init__()
        if recon_type not in ("l2", "l1", "cos"):
            raise ValueError(
                f"recon_type must be 'l2', 'l1' or 'cos', got {recon_type!r}"
            )
        self.recon_type = recon_type

    def _per_row(self, x_hat: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Per-row reconstruction distance, shape (B,).

        "l1" / "l2" average the element-wise error over the last dim; "cos" is
        ``1 - cosine_similarity`` along the last dim.
        """
        if self.recon_type == "l1":
            return F.l1_loss(x_hat, x, reduction="none").mean(dim=-1)
        if self.recon_type == "cos":
            return 1 - F.cosine_similarity(x_hat, x, dim=-1)
        return F.mse_loss(x_hat, x, reduction="none").mean(dim=-1)  # "l2"

    @staticmethod
    def _masked_mean(
        per_sample: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Mean over the masked-in rows (all rows if ``mask`` is None).

        The masked mean divides by the valid-row count. An empty mask yields 0:
        the numerator is then 0 and the denominator is replaced by 1, so the
        division is never 0/0. That keeps the *gradient* finite as well as the
        value (a 0/0 that is only sanitised after the division can still
        back-propagate NaN). No data-dependent branching ->
        ``torch.compile``-friendly.

        Args:
            per_sample (Tensor): per-row values, shape (B,).
            mask (Tensor, optional): per-row bool; rows to include.

        Returns:
            Tensor: scalar mean over the included rows.
        """
        if mask is None:
            return per_sample.mean()
        mask = mask.float()
        n_valid = mask.sum()
        # Swap a zero count for 1 instead of dividing by zero.
        safe_count = torch.where(n_valid > 0, n_valid, torch.ones_like(n_valid))
        return (per_sample * mask).sum() / safe_count

    def forward(
        self,
        x_hat: torch.Tensor,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Mask-aware reconstruction loss.

        Args:
            x_hat (Tensor): reconstruction (decoder output), shape (B, D).
            x (Tensor): the input it reconstructs, shape (B, D).
            mask (Tensor, optional): per-row bool; rows to include (all if None).

        Returns:
            Tensor: scalar reconstruction loss.
        """
        return self._masked_mean(self._per_row(x_hat, x), mask)
class SidReconLossTest(unittest.TestCase):
    """Tests for the reconstruction-loss module (per-row distance + reduction)."""

    def test_l2_is_per_row_mse(self) -> None:
        """Ones vs zeros: the squared error averaged over the last dim is 1."""
        criterion = SidReconLoss("l2")
        dist_per_row = criterion._per_row(torch.ones(3, 4), torch.zeros(3, 4))
        self.assertEqual(dist_per_row.shape, (3,))
        torch.testing.assert_close(dist_per_row, torch.ones(3))

    def test_l1_is_per_row_mae(self) -> None:
        """Ones vs zeros: the absolute error averaged over the last dim is 1."""
        criterion = SidReconLoss("l1")
        dist_per_row = criterion._per_row(torch.ones(2, 5), torch.zeros(2, 5))
        torch.testing.assert_close(dist_per_row, torch.ones(2))

    def test_cos_is_one_minus_cosine(self) -> None:
        """Identical vectors have cosine 1, hence distance 0."""
        vec = torch.tensor([[1.0, 0.0]])
        dist_per_row = SidReconLoss("cos")._per_row(vec, vec.clone())
        torch.testing.assert_close(dist_per_row, torch.zeros(1), atol=1e-6, rtol=0)

    @parameterized.expand([("l2",), ("l1",), ("cos",)])
    def test_each_type_scalar_and_backprops(self, recon_type) -> None:
        """Each distance reduces to a finite scalar and back-propagates."""
        decoded = torch.randn(4, 6, requires_grad=True)
        target = torch.randn(4, 6)
        value = SidReconLoss(recon_type)(decoded, target)
        self.assertEqual(value.shape, ())  # forward reduces to a scalar
        self.assertTrue(torch.isfinite(value))
        value.backward()  # grad must flow back to the (decoder) input
        self.assertIsNotNone(decoded.grad)

    def test_unknown_type_raises(self) -> None:
        """An unsupported distance name is rejected at construction."""
        with self.assertRaisesRegex(ValueError, "recon_type"):
            SidReconLoss("nope")

    # --- masked-mean reduction (forward's mask handling) ---

    def test_no_mask_is_plain_mean(self) -> None:
        """Without a mask the reduction is the ordinary mean."""
        values = torch.tensor([1.0, 2.0, 3.0, 4.0])
        reduced = SidReconLoss._masked_mean(values)
        torch.testing.assert_close(reduced, values.mean())

    def test_mask_averages_over_valid_rows_only(self) -> None:
        """Only masked-in rows enter the average: (1+3)/2."""
        values = torch.tensor([1.0, 2.0, 3.0, 4.0])
        keep = torch.tensor([True, False, True, False])
        reduced = SidReconLoss._masked_mean(values, keep)
        torch.testing.assert_close(reduced, torch.tensor(2.0))

    def test_empty_mask_is_zero_not_nan(self) -> None:
        """An all-False mask reduces to exactly 0."""
        values = torch.tensor([1.0, 2.0, 3.0])
        keep_none = torch.zeros(3, dtype=torch.bool)
        reduced = SidReconLoss._masked_mean(values, keep_none)
        self.assertEqual(reduced.item(), 0.0)

    def test_forward_applies_mask(self) -> None:
        """``forward`` passes the mask through to the reduction.

        l1 per-row of [[1,1],[2,2],[3,3],[4,4]] vs 0 is [1,2,3,4]; the mask
        keeps rows 0 and 2, so the loss is (1+3)/2.
        """
        decoded = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]])
        keep = torch.tensor([True, False, True, False])
        value = SidReconLoss("l1")(decoded, torch.zeros(4, 2), keep)
        torch.testing.assert_close(value, torch.tensor(2.0))


if __name__ == "__main__":
    unittest.main()
proto carries — - ``embedding_feature_name`` (``_embedding_feature_name``), ``input_dim`` - (``_input_dim``), ``normalize_residuals`` (``_normalize_residuals``), + - the shared config fields every SID proto carries — ``feature_group`` + (``_feature_group``), ``normalize_residuals`` (``_normalize_residuals``), and the per-layer ``codebook`` (``_n_embed_list`` / ``_n_layers``), - - reading the item-embedding feature out of ``Batch.dense_features``, + - building the main input through the framework's :class:`EmbeddingGroup` + (:meth:`init_input` / :meth:`build_input`), so a SID model consumes the + same grouped/concatenated feature tensor as every other model and + ``_input_dim`` is *derived* from the group's total dimension (supporting + multiple content embeddings + side-info in one group), - the eval metrics every SID model reports — reconstruction ``mse`` and ``unique_sid_ratio`` (mean per-batch unique-SID ratio, a diversity proxy). @@ -63,48 +71,117 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) cfg = self._model_config - # Config fields shared by every SID model (present on each SID proto - # message): the item-embedding feature, the input dimension, the - # residual-normalization toggle, and the per-layer codebook. - self._embedding_feature_name = cfg.embedding_feature_name - self._input_dim = cfg.input_dim self._normalize_residuals = cfg.normalize_residuals if not cfg.codebook: raise ValueError("codebook must be set, e.g. [256, 256, 256]") self._n_embed_list = list(cfg.codebook) - # Fail fast: a zero codebook entry / input_dim==0 only errors opaquely - # deep inside faiss, after the whole training pass. + # Fail fast: a zero entry only errors opaquely deep in faiss later. 
if any(k < 1 for k in self._n_embed_list): raise ValueError( f"every codebook entry must be >= 1, got {self._n_embed_list}" ) - if self._input_dim < 1: - raise ValueError(f"input_dim must be >= 1, got {self._input_dim}") self._n_layers = len(self._n_embed_list) - def _extract_feature( - self, batch: Batch, feature_name: Optional[str] = None - ) -> torch.Tensor: - """Extract a named dense feature from ``Batch.dense_features``. + self.init_input() + self._feature_group = self.embedding_group.group_names()[0] + self._input_dim = self.embedding_group.group_total_dim(self._feature_group) + if self._input_dim < 1: + raise ValueError( + f"feature group {self._feature_group!r} has total dim " + f"{self._input_dim}; it must be >= 1" + ) - Args: - batch (Batch): input batch data. - feature_name (str, optional): feature name to extract. - Defaults to ``self._embedding_feature_name``. - """ - if feature_name is None: - feature_name = self._embedding_feature_name - kt = batch.dense_features[BASE_DATA_GROUP] - return kt[feature_name] + def init_input(self) -> None: + """Build the :class:`EmbeddingGroup` from features + feature groups.""" + self.embedding_group = EmbeddingGroup(self._features, self._feature_groups) + + def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Build grouped input features: ``{group_name: (B, group_total_dim)}``.""" + return self.embedding_group(batch) def init_loss(self) -> None: - """Initialize loss modules. + """Initialize SID loss modules from ``ModelConfig.losses``. - SID models compute their losses internally and pass them through - ``predictions``; there is no external loss module to register. + Each ``LossConfig`` sets one ``sid_loss`` oneof variant (a reconstruction + loss, the commitment loss, or the contrastive loss). Mirrors ``RankModel``: + the config drives what is registered here, and :meth:`loss` computes them + from ``predictions``. All three are registered as ``_loss_modules`` entries. 
""" - pass + for loss_cfg in self._base_model_config.losses: + self._init_sid_loss_impl(loss_cfg) + + def _init_sid_loss_impl(self, loss_cfg: LossConfig) -> None: + """Register the loss module for one ``sid_loss`` config.""" + loss_type = loss_cfg.WhichOneof("sid_loss") + if loss_type == "recon_loss": + self._loss_modules["recon_loss"] = SidReconLoss( + loss_cfg.recon_loss.recon_type + ) + elif loss_type == "commitment_loss": + cfg = loss_cfg.commitment_loss + latent_weight = list(cfg.latent_weight) if cfg.latent_weight else (1.0, 0.5) + self._loss_modules["commitment_loss"] = SidCommitmentLoss( + latent_weight=latent_weight, + commitment_type=cfg.commitment_type, + ) + elif loss_type == "contrastive_loss": + # The contrastive module owns its learnable temperatures. + self._loss_modules["contrastive_loss"] = SidContrastiveLoss() + else: + raise ValueError( + f"LossConfig for a SID model must set a sid_loss variant, " + f"got {loss_type!r}" + ) + + def loss( + self, predictions: Dict[str, torch.Tensor], batch: Batch + ) -> Dict[str, torch.Tensor]: + """Compute the configured SID losses from ``predictions``. + + Args: + predictions (dict): a dict of predicted result (the raw tensors the + losses consume — ``x_hat``/``recon_target`` for reconstruction, + ``encoder_out``/``latents`` for commitment, and the contrastive + embeds). + batch (Batch): input batch data. + + Return: + losses (dict): a dict of loss tensor keyed by the sid_loss variant. 
+ """ + losses: Dict[str, torch.Tensor] = {} + for loss_cfg in self._base_model_config.losses: + losses.update(self._sid_loss_impl(predictions, loss_cfg)) + return losses + + def _sid_loss_impl( + self, predictions: Dict[str, torch.Tensor], loss_cfg: LossConfig + ) -> Dict[str, torch.Tensor]: + """Compute one ``sid_loss`` term from ``predictions``.""" + loss_type = loss_cfg.WhichOneof("sid_loss") + if loss_type == "recon_loss": + loss = self._loss_modules["recon_loss"]( + predictions["x_hat"], + predictions["recon_target"], + predictions.get("recon_mask"), + ) + return {"recon_loss": loss} + elif loss_type == "commitment_loss": + loss = self._loss_modules["commitment_loss"]( + predictions["encoder_out"], predictions["latents"] + ) + return {"commitment_loss": loss} + elif loss_type == "contrastive_loss": + loss = self._loss_modules["contrastive_loss"]( + predictions["embed_a"], + predictions["embed_b"], + predictions["embed_a_ori"], + predictions["embed_b_ori"], + predictions["pair_mask"], + ) + return {"contrastive_loss": loss} + else: + raise ValueError(f"unsupported sid_loss variant: {loss_type!r}") def init_metric(self) -> None: """Initialize the eval metrics shared by all SID models. @@ -130,13 +207,17 @@ def update_metric( batch: Batch, losses: Optional[Dict[str, torch.Tensor]] = None, ) -> None: - """Update eval metrics from the reconstruction + the re-extracted input. + """Update eval metrics from the reconstruction vs. the input embedding. ``predictions["x_hat"]`` is the model's reconstruction of the input embedding (the centroid sum for RQ-KMeans, the decoder output for - RQ-VAE). Subclasses expose it only when it is meaningful, so a - not-yet-fitted model omits it and this logs nothing. The target - embedding is re-extracted from ``batch`` (it is an input, not an output). + RQ-VAE); ``predictions["recon_target"]`` is the input it reconstructs. + Subclasses expose both only when meaningful, so a not-yet-fitted model + omits them and this logs nothing. 
(Reading the target from + ``predictions`` avoids a second ``build_input`` pass over ``batch``.) + For the mixed contrastive path the reconstruction is scored only on the + non-pair rows (``recon_mask``), matching the masked training recon loss + so the eval mse/rel_loss stay comparable to the optimized objective. Args: predictions (dict): a dict of predicted result. @@ -146,7 +227,11 @@ def update_metric( if "x_hat" not in predictions: return recon = predictions["x_hat"] - embedding = self._extract_feature(batch) + embedding = predictions["recon_target"] + recon_mask = predictions.get("recon_mask") + if recon_mask is not None: + recon = recon[recon_mask] + embedding = embedding[recon_mask] self._metric_modules["mse"].update(recon, embedding) self._metric_modules["rel_loss"].update(recon, embedding) self._metric_modules["unique_sid_ratio"].update(predictions["codes"]) @@ -158,7 +243,8 @@ def update_train_metric( ) -> None: """Update train-path metric state. - Default is a no-op: K-Means has no train-time codes, so only models - with a meaningful train signal (RQ-VAE) override this. + Default no-op: the current SID models report metrics at eval (after the + codebook is fit / the decoder is trained), not during training. A + subclass with a meaningful train-time signal may override this. """ return diff --git a/tzrec/models/sid_rqkmeans.py b/tzrec/models/sid_rqkmeans.py index 59b05af41..933481fa2 100644 --- a/tzrec/models/sid_rqkmeans.py +++ b/tzrec/models/sid_rqkmeans.py @@ -121,7 +121,7 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: Return: predictions (dict): a dict of predicted result. """ - embedding = self._extract_feature(batch) + embedding = self.build_input(batch)[self._feature_group] # Training: reservoir-sample only; codes are dummy until the fit. 
if self.is_train: @@ -139,11 +139,9 @@ def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: "codes": codes, } - # Expose the centroid-sum reconstruction (``x_hat``) for update_metric - # only once fitted — pre-fit it is all-zeros, so omitting it skips the - # eval metrics. (Meaningful only with normalize_residuals=False.) if self.is_eval and self._quantizer.is_fitted: predictions["x_hat"] = quantized + predictions["recon_target"] = embedding return predictions diff --git a/tzrec/models/sid_rqkmeans_test.py b/tzrec/models/sid_rqkmeans_test.py index 0b68fefa6..128b5acbc 100644 --- a/tzrec/models/sid_rqkmeans_test.py +++ b/tzrec/models/sid_rqkmeans_test.py @@ -17,12 +17,37 @@ from torchrec import KeyedTensor from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.features.feature import create_features from tzrec.models.sid_rqkmeans import SidRqkmeans -from tzrec.protos import model_pb2 +from tzrec.protos import feature_pb2, model_pb2 from tzrec.protos.models import sid_model_pb2 from tzrec.utils.state_dict_util import init_parameters +def _features_and_groups(input_dim: int): + """Real ``item_emb`` raw feature + the ``deep`` group it feeds. + + SID models consume the framework's EmbeddingGroup (built from these), and + derive the K-Means dimension from the ``deep`` group's total dim — so real + features + feature_groups are required, as in every other model test. 
+ """ + feature_cfgs = [ + feature_pb2.FeatureConfig( + raw_feature=feature_pb2.RawFeature( + feature_name="item_emb", value_dim=input_dim + ) + ) + ] + groups = [ + model_pb2.FeatureGroupConfig( + group_name="deep", + feature_names=["item_emb"], + group_type=model_pb2.FeatureGroupType.DEEP, + ) + ] + return create_features(feature_cfgs), groups + + def _batch_from_rows(rows: torch.Tensor) -> Batch: """Wrap explicit ``item_emb`` rows in a minimal Batch.""" dense_feature = KeyedTensor.from_tensor_list(keys=["item_emb"], tensors=[rows]) @@ -58,26 +83,23 @@ def _create_model( normalize_residuals=False, train_sample_size=0, ): - """Build a SidRqkmeans on CPU with params initialized. - - SID models read the item-embedding dense feature directly from the - batch and do not consume feature_groups, so none is set. - """ + """Build a SidRqkmeans on CPU with params initialized.""" n_embed_list = codebook if codebook is not None else [16] * n_layers faiss_kwargs = sid_model_pb2.FaissKmeansConfig( niter=niter, verbose=False, seed=1234 ) cfg = sid_model_pb2.SidRqkmeans( - input_dim=input_dim, codebook=n_embed_list, normalize_residuals=normalize_residuals, faiss_kmeans_kwargs=faiss_kwargs, - embedding_feature_name="item_emb", train_sample_size=train_sample_size, ) + features, feature_groups = _features_and_groups(input_dim) model = SidRqkmeans( - model_config=model_pb2.ModelConfig(sid_rqkmeans=cfg), - features=[], + model_config=model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqkmeans=cfg + ), + features=features, labels=[], ) init_parameters(model, device=torch.device("cpu")) @@ -119,9 +141,9 @@ def test_init_raises_on_zero_codebook_entry(self) -> None: with self.assertRaisesRegex(ValueError, "codebook entry must be >= 1"): self._create_model(codebook=[16, 0]) - def test_init_raises_on_zero_input_dim(self) -> None: - """input_dim < 1 fails fast at construction.""" - with self.assertRaisesRegex(ValueError, "input_dim must be >= 1"): + def 
test_init_raises_on_zero_dim_feature_group(self) -> None: + """A feature group with total dim 0 fails fast (derived input_dim < 1).""" + with self.assertRaisesRegex(ValueError, "must be >= 1"): self._create_model(input_dim=0) def test_predict_collects_buffer(self) -> None: @@ -238,7 +260,7 @@ def test_normalize_residuals_end_to_end(self) -> None: self.assertTrue((codes >= 0).all() and (codes < 16).all()) def test_eval_and_inference_predict_contract(self) -> None: - """Eval (post-fit) exposes codes + x_hat; inference is codes-only.""" + """Eval (post-fit) exposes codes + x_hat + recon_target; infer codes-only.""" try: import faiss # noqa: F401 except ImportError: @@ -251,12 +273,12 @@ def test_eval_and_inference_predict_contract(self) -> None: model.predict(_make_batch(B, input_dim)) model.on_train_end() - # Eval mode (fitted): the reconstruction is exposed as ``x_hat`` for - # update_metric; the input embedding is re-extracted from the batch - # there, not threaded through predictions. + # Eval mode (fitted): the reconstruction (``x_hat``) and its target + # (``recon_target``) are both exposed for update_metric, so it scores + # without a second build_input pass over the batch. model.eval() eval_preds = model.predict(_make_batch(B, input_dim)) - self.assertEqual(set(eval_preds.keys()), {"codes", "x_hat"}) + self.assertEqual(set(eval_preds.keys()), {"codes", "x_hat", "recon_target"}) # Inference (serving) mode: codes-only contract. model.set_is_inference(True) diff --git a/tzrec/models/sid_rqvae.py b/tzrec/models/sid_rqvae.py new file mode 100644 index 000000000..732c520b4 --- /dev/null +++ b/tzrec/models/sid_rqvae.py @@ -0,0 +1,256 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SidRqvae: SID generation model using RQ-VAE (Encoder + VQ + Decoder). + +End-to-end differentiable training. The reconstruction, commitment and optional +contrastive losses are configured via ``ModelConfig.losses`` (the +``LossConfig`` ``sid_loss`` oneof) and computed centrally in +:meth:`BaseSidModel.loss`; :meth:`predict` only produces the raw tensors those +losses consume. The encoder/decoder and residual vector quantizer live directly +on the model — there is no intermediate ``RQVAE`` module wrapper. +""" + +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn + +from tzrec.datasets.utils import Batch +from tzrec.features.feature import BaseFeature +from tzrec.models.sid_model import BaseSidModel +from tzrec.modules.mlp import MLP +from tzrec.modules.sid.residual_vector_quantizer import ( + ResidualVectorQuantizer, +) +from tzrec.modules.sid.types import ResidualQuantizerOutput +from tzrec.protos.model_pb2 import ModelConfig +from tzrec.utils.config_util import config_to_kwargs +from tzrec.utils.logging_util import logger + + +class SidRqvae(BaseSidModel): + """SID generation model using RQ-VAE (Encoder + VQ + Decoder). + + Encoder/Decoder are configurable-depth MLPs built from ``hidden_dims``: + Encoder: input_dim -> hidden_dims[0] -> ... -> embed_dim + Decoder: embed_dim -> ... -> hidden_dims[0] -> input_dim + (ReLU between hidden layers; the decoder mirrors the encoder.) + + Losses are config-driven (``ModelConfig.losses`` / ``sid_loss`` oneof). 
When a + ``contrastive_loss`` is configured, ``predict`` runs a dual (paired) path and + the masked contrastive loss is applied to the contrastive-pair rows. + + Args: + model_config (ModelConfig): an instance of ModelConfig. + features (list): list of features. + labels (list): list of label names. + sample_weights (list): sample weight names. + """ + + def __init__( + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + super().__init__(model_config, features, labels, sample_weights, **kwargs) + + cfg = self._model_config # SidRqvae proto message + + self._init_contrastive() + + embed_dim = cfg.embed_dim + # Fail fast (parity with BaseSidModel's codebook/input_dim checks): a zero + # dim only errors opaquely deep in nn.Linear/Embedding otherwise. + if embed_dim < 1: + raise ValueError(f"embed_dim must be >= 1, got {embed_dim}") + hidden_dims = ( + list(cfg.hidden_dims) if cfg.hidden_dims else [self._input_dim // 2] + ) + if any(h < 1 for h in hidden_dims): + raise ValueError(f"every hidden_dims entry must be >= 1, got {hidden_dims}") + + sinkhorn_cfg = config_to_kwargs(cfg.sinkhorn_config) + + # MLP activates its last layer; the trailing bare Linear keeps the + # latent / reconstruction unbounded. 
+ self._encoder = nn.Sequential( + MLP(self._input_dim, hidden_units=hidden_dims), + nn.Linear(hidden_dims[-1], embed_dim), + ) + self._decoder = nn.Sequential( + MLP(embed_dim, hidden_units=list(reversed(hidden_dims))), + nn.Linear(hidden_dims[0], self._input_dim), + ) + + self._quantizer = ResidualVectorQuantizer( + embed_dim=embed_dim, + n_layers=self._n_layers, + n_embed=self._n_embed_list, + forward_mode=cfg.forward_mode, + normalize_residuals=self._normalize_residuals, + distance_type=cfg.distance_type, + rotation_trick=cfg.rotation_trick, + kmeans_init=cfg.kmeans_init, + use_sinkhorn=sinkhorn_cfg["enabled"], + sinkhorn_iters=sinkhorn_cfg["iters"], + sinkhorn_epsilon=sinkhorn_cfg["epsilon"], + ) + + logger.info( + "SidRqvae init: input_dim=%d, embed_dim=%d, hidden_dims=%s, " + "n_layers=%d, n_embed=%s, use_contrastive=%s", + self._input_dim, + embed_dim, + hidden_dims, + self._n_layers, + self._n_embed_list, + self._use_contrastive, + ) + + def _init_contrastive(self) -> None: + """Read and validate the pair-contrastive wiring (``contrastive_config``). + + Sets ``_use_contrastive`` and the paired / pair-flag group names, and + enforces: ``contrastive_config`` (structure) and a ``contrastive_loss`` + entry (objective) are set together; the paired group exists and matches + ``input_dim`` (it shares the encoder); the pair-flag group is a single + dim-1 raw flag. Must run after ``super().__init__()`` — it needs + ``embedding_group`` / ``_input_dim``. 
+ """ + cfg = self._model_config + self._pair_feature_group = None + self._pair_flag_feature_group = None + self._use_contrastive = cfg.HasField("contrastive_config") + has_contrastive_obj = any( + lc.WhichOneof("sid_loss") == "contrastive_loss" + for lc in self._base_model_config.losses + ) + if self._use_contrastive != has_contrastive_obj: + raise ValueError( + "contrastive_config (model structure) and a contrastive_loss entry " + "in losses (the objective) must be set together; got " + f"contrastive_config={self._use_contrastive}, " + f"contrastive_loss={has_contrastive_obj}" + ) + if not self._use_contrastive: + return + self._pair_feature_group = cfg.contrastive_config.pair_feature_group + self._pair_flag_feature_group = cfg.contrastive_config.pair_flag_feature_group + for grp in (self._pair_feature_group, self._pair_flag_feature_group): + if not self.embedding_group.has_group(grp): + raise ValueError( + f"contrastive group {grp!r} is not in model_config.feature_groups" + f" {self.embedding_group.group_names()}" + ) + pair_dim = self.embedding_group.group_total_dim(self._pair_feature_group) + if pair_dim != self._input_dim: + raise ValueError( + f"pair_feature_group {self._pair_feature_group!r} has total " + f"dim {pair_dim}, but it is encoded by the same encoder as the " + f"main feature_group (dim {self._input_dim}); the two must match" + ) + flag_dim = self.embedding_group.group_total_dim(self._pair_flag_feature_group) + if flag_dim != 1: + raise ValueError( + f"pair_flag_feature_group {self._pair_flag_feature_group!r} must " + f"be a single dim-1 raw flag, got total dim {flag_dim}" + ) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode. (B, input_dim) -> (B, embed_dim).""" + return self._encoder(x) + + def _decode(self, z_q: torch.Tensor) -> torch.Tensor: + """Decode. (B, embed_dim) -> (B, input_dim).""" + return self._decoder(z_q) + + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: + """Predict the model. 
+ + Returns the raw tensors the configured losses consume (computed in + :meth:`BaseSidModel.loss`); inference emits codes only. + + Args: + batch (Batch): input batch data. + + Return: + predictions (dict): a dict of predicted result. + """ + grouped = self.build_input(batch) + embedding = grouped[self._feature_group] + if self._is_inference: + return {"codes": self._quantizer.get_codes(self._encode(embedding))} + if self._use_contrastive: + return self._predict_mixed(grouped) + return self._predict_rqvae(embedding) + + def _rqvae_pass( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, ResidualQuantizerOutput, torch.Tensor]: + """One RQ-VAE pass over ``x``: encode -> quantize -> decode. + + Returns the encoder output ``z_e`` (commitment operand), the quantizer + output ``quant`` (cluster_ids / latents / quantized_embeddings) and the + decoded reconstruction ``x_hat``. + """ + z_e = self._encode(x) + quant = self._quantizer(z_e) + return z_e, quant, self._decode(quant.quantized_embeddings) + + def _predict_rqvae(self, embedding: torch.Tensor) -> Dict[str, torch.Tensor]: + """Standard RQ-VAE: a single reconstruction pass.""" + z_e, quant, x_hat = self._rqvae_pass(embedding) + return { + "codes": quant.cluster_ids, + "x_hat": x_hat, + "recon_target": embedding, + "encoder_out": z_e, + "latents": quant.latents, + } + + def _predict_mixed( + self, grouped: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Mixed recon + contrastive: dual path over the main + paired groups. + + ``encoder_out`` / ``latents`` stack both paths so the commitment loss + averages over them; ``recon_mask`` (= non-pair rows) restricts the recon + loss to reconstruction-only rows. + + Args: + grouped (dict): the EmbeddingGroup output (group name -> tensor). 
+ """ + embedding = grouped[self._feature_group] + fea2 = grouped[self._pair_feature_group] + is_pair_raw = grouped[self._pair_flag_feature_group] + pair_mask = is_pair_raw.view(is_pair_raw.shape[0], -1)[:, 0] > 0.5 + + z_e1, quant1, x_hat1 = self._rqvae_pass(embedding) + z_e2, quant2, x_hat2 = self._rqvae_pass(fea2) + + return { + "codes": quant1.cluster_ids, + "x_hat": x_hat1, + "recon_target": embedding, + "recon_mask": ~pair_mask, + "encoder_out": torch.cat([z_e1, z_e2], dim=0), + "latents": torch.cat([quant1.latents, quant2.latents], dim=0), + # generic contrastive operands (view a = main, view b = paired): + "embed_a": x_hat1, + "embed_b": x_hat2, + "embed_a_ori": embedding, + "embed_b_ori": fea2, + "pair_mask": pair_mask, + } diff --git a/tzrec/models/sid_rqvae_test.py b/tzrec/models/sid_rqvae_test.py new file mode 100644 index 000000000..292fe552e --- /dev/null +++ b/tzrec/models/sid_rqvae_test.py @@ -0,0 +1,461 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch +from parameterized import parameterized +from torchrec import KeyedTensor + +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch +from tzrec.features.feature import create_features +from tzrec.models.sid_rqvae import SidRqvae +from tzrec.protos import feature_pb2, loss_pb2, model_pb2 +from tzrec.protos.models import sid_model_pb2 +from tzrec.utils.state_dict_util import init_parameters + + +def _features_and_groups( + input_dim: int, use_contrastive: bool = False, pair_emb_dim: int = None, flag_dim=1 +): + """Real raw features + feature groups for a SID model. + + Mirrors how every other model test wires inputs: ``create_features`` builds + the ``BaseFeature`` objects and ``feature_groups`` (consumed by the model's + :class:`EmbeddingGroup`) name the main ``deep`` group — plus, for the + contrastive path, the paired group and the per-row pair-flag group. + ``pair_emb_dim`` (default: match ``input_dim``) sizes the paired group and + ``flag_dim`` (default 1) sizes the pair-flag group, so a test can + deliberately mismatch either. 
+ """ + + def _raw(name: str, dim: int) -> feature_pb2.FeatureConfig: + return feature_pb2.FeatureConfig( + raw_feature=feature_pb2.RawFeature(feature_name=name, value_dim=dim) + ) + + def _deep(group_name: str, feature_name: str) -> model_pb2.FeatureGroupConfig: + return model_pb2.FeatureGroupConfig( + group_name=group_name, + feature_names=[feature_name], + group_type=model_pb2.FeatureGroupType.DEEP, + ) + + feature_cfgs = [_raw("item_emb", input_dim)] + groups = [_deep("deep", "item_emb")] + if use_contrastive: + feature_cfgs += [ + _raw("pair_emb", pair_emb_dim if pair_emb_dim is not None else input_dim), + _raw("is_pair", flag_dim), + ] + groups += [_deep("pair", "pair_emb"), _deep("pair_flag", "is_pair")] + return create_features(feature_cfgs), groups + + +def _make_batch(batch_size: int, input_dim: int) -> Batch: + """Create a minimal Batch with the ``item_emb`` dense feature.""" + dense_feature = KeyedTensor.from_tensor_list( + keys=["item_emb"], tensors=[torch.randn(batch_size, input_dim)] + ) + return Batch( + dense_features={BASE_DATA_GROUP: dense_feature}, + sparse_features={}, + labels={}, + ) + + +def _recon_loss_cfg(recon_type: str = "l2") -> loss_pb2.LossConfig: + """A LossConfig with a recon_loss term of the given recon_type.""" + lc = loss_pb2.LossConfig() + lc.recon_loss.recon_type = recon_type + return lc + + +def _commitment_cfg( + latent_weight=(1.0, 0.5), commitment_type="l2" +) -> loss_pb2.LossConfig: + lc = loss_pb2.LossConfig() + lc.commitment_loss.latent_weight.extend(latent_weight) + lc.commitment_loss.commitment_type = commitment_type + return lc + + +def _contrastive_cfg() -> loss_pb2.LossConfig: + # The contrastive objective marker (empty); the paired-feature wiring lives + # on the model proto (SidRqvae.contrastive_config), set in _create_model. 
+ lc = loss_pb2.LossConfig() + lc.contrastive_loss.SetInParent() + return lc + + +class SidRqvaeTest(unittest.TestCase): + """Tests for SidRqvae model.""" + + def _create_model( + self, + use_contrastive=False, + input_dim=32, + embed_dim=8, + n_layers=2, + recon="l2", + ): + """Helper to create a SidRqvae model with config-driven losses.""" + n_embed_list = [16] * n_layers + sid_rqvae_cfg = sid_model_pb2.SidRqvae( + embed_dim=embed_dim, + codebook=n_embed_list, + forward_mode="ste", + kmeans_init=False, + ) + losses = [_recon_loss_cfg(recon), _commitment_cfg()] + if use_contrastive: + sid_rqvae_cfg.contrastive_config.pair_feature_group = "pair" + sid_rqvae_cfg.contrastive_config.pair_flag_feature_group = "pair_flag" + losses.append(_contrastive_cfg()) + + # Real features + feature_groups: input_dim is derived from the group. + features, feature_groups = _features_and_groups(input_dim, use_contrastive) + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqvae=sid_rqvae_cfg, losses=losses + ) + model = SidRqvae(model_config=model_config, features=features, labels=[]) + init_parameters(model, device=torch.device("cpu")) + return model + + def _contrastive_batch(self, B, input_dim, is_pair): + return Batch( + dense_features={ + BASE_DATA_GROUP: KeyedTensor.from_tensor_list( + keys=["item_emb", "pair_emb", "is_pair"], + tensors=[ + torch.randn(B, input_dim), + torch.randn(B, input_dim), + is_pair, + ], + ) + }, + sparse_features={}, + labels={}, + ) + + def test_rqvae_train_mode(self) -> None: + """Test SidRqvae in train mode: predict -> loss -> metric.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.train() + model.init_loss() + model.init_metric() + + batch = _make_batch(B, input_dim) + predictions = model.predict(batch) + + # predict() returns only the raw tensors the losses consume. 
+ self.assertIn("codes", predictions) + self.assertIn("x_hat", predictions) + self.assertIn("encoder_out", predictions) + self.assertIn("latents", predictions) + self.assertEqual(predictions["codes"].shape[0], B) + + # loss() computes the configured recon + commitment terms. + losses = model.loss(predictions, batch) + self.assertIn("recon_loss", losses) + self.assertIn("commitment_loss", losses) + + total_loss = sum(losses.values()) + self.assertTrue(total_loss.requires_grad) + + model.update_metric(predictions, batch, losses) + metrics = model.compute_metric() + self.assertIn("mse", metrics) + self.assertIn("unique_sid_ratio", metrics) + + def test_rqvae_eval_mode(self) -> None: + """Test SidRqvae in eval mode: predict returns the recon fields.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.eval() + + predictions = model.predict(_make_batch(B, input_dim)) + + # Eval mode (not inference) exposes x_hat for the metric + losses. + self.assertIn("codes", predictions) + self.assertIn("x_hat", predictions) + self.assertIn("encoder_out", predictions) + self.assertIn("latents", predictions) + + def test_rqvae_inference_mode(self) -> None: + """Test SidRqvae in inference mode: only codes returned.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.eval() + model.set_is_inference(True) + + predictions = model.predict(_make_batch(B, input_dim)) + self.assertIn("codes", predictions) + self.assertNotIn("x_hat", predictions) + self.assertNotIn("latents", predictions) + + def test_rqvae_contrastive_mode(self) -> None: + """Test SidRqvae with the mixed recon + contrastive path.""" + B, input_dim = 8, 32 + model = self._create_model(input_dim=input_dim, use_contrastive=True) + model.train() + model.init_loss() + + is_pair = torch.zeros(B, 1) + is_pair[B // 2 :] = 1.0 # second half are contrastive pairs + batch = self._contrastive_batch(B, input_dim, is_pair) + + predictions = model.predict(batch) + 
self.assertIn("codes", predictions) + self.assertIn("x_hat", predictions) + self.assertIn("embed_a", predictions) + self.assertEqual(predictions["codes"].shape[0], B) + + losses = model.loss(predictions, batch) + self.assertIn("recon_loss", losses) + self.assertIn("commitment_loss", losses) + self.assertIn("contrastive_loss", losses) + + total_loss = sum(losses.values()) + self.assertTrue(total_loss.requires_grad) + total_loss.backward() + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 for p in model.parameters() + ) + self.assertTrue(has_grad) + + def test_rqvae_contrastive_all_recon(self) -> None: + """Mixed mode, all-recon batch: contrastive term 0, recon term > 0.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim, use_contrastive=True) + model.train() + model.init_loss() + + batch = self._contrastive_batch(B, input_dim, torch.zeros(B, 1)) + losses = model.loss(model.predict(batch), batch) + self.assertEqual(losses["contrastive_loss"].item(), 0.0) + self.assertGreater(losses["recon_loss"].item(), 0.0) + + def test_rqvae_contrastive_all_pair(self) -> None: + """Mixed mode, all-pair batch: recon term 0, contrastive term > 0.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim, use_contrastive=True) + model.train() + model.init_loss() + + batch = self._contrastive_batch(B, input_dim, torch.ones(B, 1)) + losses = model.loss(model.predict(batch), batch) + self.assertEqual(losses["recon_loss"].item(), 0.0) + self.assertGreater(losses["contrastive_loss"].item(), 0.0) + + def test_rqvae_backward(self) -> None: + """Test that backward pass works without errors.""" + B, input_dim = 4, 32 + model = self._create_model(input_dim=input_dim) + model.train() + model.init_loss() + + batch = _make_batch(B, input_dim) + losses = model.loss(model.predict(batch), batch) + sum(losses.values()).backward() + + has_grad = any( + p.grad is not None and p.grad.abs().sum() > 0 for p in model.parameters() + ) + 
self.assertTrue(has_grad) + + def test_commitment_latent_weight_wrong_length_raises(self) -> None: + """A commitment_loss with a bad latent_weight length fails in init_loss.""" + features, feature_groups = _features_and_groups(32) + for bad in ([1.0], [1.0, 0.5, 0.25]): + cfg = sid_model_pb2.SidRqvae( + embed_dim=8, codebook=[16, 16], kmeans_init=False + ) + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, + sid_rqvae=cfg, + losses=[_commitment_cfg(latent_weight=bad)], + ) + model = SidRqvae(model_config=model_config, features=features, labels=[]) + with self.assertRaisesRegex(ValueError, "latent_weight"): + model.init_loss() + + def test_pair_feature_group_dim_mismatch_raises(self) -> None: + """A paired group whose dim != the main group fails fast at init. + + The paired feature is encoded by the same encoder as the main input, so + a dim mismatch would otherwise crash with an opaque matmul shape error + on the first contrastive forward — not at construction. + """ + features, feature_groups = _features_and_groups( + 32, use_contrastive=True, pair_emb_dim=16 + ) + cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) + cfg.contrastive_config.pair_feature_group = "pair" + cfg.contrastive_config.pair_flag_feature_group = "pair_flag" + model_config = model_pb2.ModelConfig( + feature_groups=feature_groups, sid_rqvae=cfg, losses=[_contrastive_cfg()] + ) + with self.assertRaisesRegex(ValueError, "must match"): + SidRqvae(model_config=model_config, features=features, labels=[]) + + def test_pair_flag_group_must_be_dim_1(self) -> None: + """A pair-flag group with dim != 1 fails fast (would mis-route rows).""" + features, feature_groups = _features_and_groups( + 32, use_contrastive=True, flag_dim=3 + ) + cfg = sid_model_pb2.SidRqvae(embed_dim=8, codebook=[16, 16], kmeans_init=False) + cfg.contrastive_config.pair_feature_group = "pair" + cfg.contrastive_config.pair_flag_feature_group = "pair_flag" + model_config = 
def test_contrastive_group_missing_raises(self) -> None:
    """A misspelled contrastive group name is rejected at construction.

    ``pair_flag_feature_group`` names a group that does not exist; building
    the model must raise a ``ValueError`` rather than deferring the failure
    to the first forward pass.
    """
    features, feature_groups = _features_and_groups(32, use_contrastive=True)

    rqvae_cfg = sid_model_pb2.SidRqvae(
        embed_dim=8, codebook=[16, 16], kmeans_init=False
    )
    contrastive = rqvae_cfg.contrastive_config
    contrastive.pair_feature_group = "pair"
    contrastive.pair_flag_feature_group = "pair_flagTYPO"

    model_config = model_pb2.ModelConfig(
        feature_groups=feature_groups,
        sid_rqvae=rqvae_cfg,
        losses=[_contrastive_cfg()],
    )
    expected_msg = "not in model_config.feature_groups"
    with self.assertRaisesRegex(ValueError, expected_msg):
        SidRqvae(model_config=model_config, features=features, labels=[])
def test_pair_flag_drives_routing_not_equality(self) -> None:
    """Routing follows the ``is_pair`` flag, not embedding equality.

    The batch carries ``pair_emb`` numerically identical to ``item_emb``
    while ``is_pair`` is all ones. The rows must still take the contrastive
    branch: recon loss exactly 0, contrastive loss strictly positive.
    """
    n_rows, feat_dim = 4, 32
    model = self._create_model(input_dim=feat_dim, use_contrastive=True)
    model.train()
    model.init_loss()

    # Drawn after model construction, as before, so the RNG stream matches.
    embeddings = torch.randn(n_rows, feat_dim)
    dense = KeyedTensor.from_tensor_list(
        keys=["item_emb", "pair_emb", "is_pair"],
        tensors=[embeddings, embeddings.clone(), torch.ones(n_rows, 1)],
    )
    batch = Batch(
        dense_features={BASE_DATA_GROUP: dense},
        sparse_features={},
        labels={},
    )

    predictions = model.predict(batch)
    losses = model.loss(predictions, batch)
    self.assertEqual(losses["recon_loss"].item(), 0.0)
    self.assertGreater(losses["contrastive_loss"].item(), 0.0)
@parameterized.expand([("l2",), ("l1",), ("cos",)])
def test_recon_type_branch(self, recon_type) -> None:
    """Each recon_type runs end-to-end (grad flows through the decoder).

    Args:
        recon_type (str): reconstruction distance under test, one of
            ``"l2"``, ``"l1"`` or ``"cos"``.
    """
    B, input_dim = 4, 32
    model = self._create_model(input_dim=input_dim, recon=recon_type)
    model.train()
    model.init_loss()

    # Build the batch ONCE and reuse it for predict and loss, like the
    # sibling tests. The previous code called ``_make_batch`` twice; if that
    # helper draws random features (not visible here -- confirm), the
    # predictions were scored against an unrelated target batch.
    batch = _make_batch(B, input_dim)
    losses = model.loss(model.predict(batch), batch)

    recon = losses["recon_loss"]
    self.assertTrue(torch.isfinite(recon), f"{recon_type} not finite")
    recon.backward()  # grad must flow through the decoder
def faiss_kmeans_fit(
    x: Any,
    dim: int,
    n_clusters: int,
    faiss_kmeans_kwargs: Optional[Dict] = None,
) -> Any:
    """Fit a single ``faiss.Kmeans(dim, n_clusters)`` on ``x`` and return it.

    This is the one-layer fit shared by the SID residual K-Means loops. The
    caller reads ``km.centroids`` and assigns points via ``km.index.search``.
    A ``gpu`` entry in ``faiss_kmeans_kwargs`` is dropped from a private copy
    of the options (the caller's dict is never mutated), and the row count is
    checked up front so a too-small input fails with a readable message.

    Args:
        x: data points, shape (N, dim) -- numpy array or torch tensor.
        dim (int): feature dimension.
        n_clusters (int): number of centroids (codebook size).
        faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans``.

    Returns:
        The trained ``faiss.Kmeans`` (read ``.centroids`` / ``.index``).

    Raises:
        ImportError: if ``faiss`` is not installed.
        RuntimeError: if ``x`` has fewer than ``n_clusters`` rows.
    """
    # Imported lazily so the module stays importable without faiss.
    try:
        import faiss
    except ImportError as e:
        raise ImportError(
            "faiss is required for SID residual K-Means. Install via "
            "`pip install faiss-cpu` or `pip install faiss-gpu`."
        ) from e

    # Work on a copy and discard `gpu`: per the original note, faiss.Kmeans
    # honors that key and would move the fit off the CPU.
    options = dict(faiss_kmeans_kwargs) if faiss_kmeans_kwargs else {}
    options.pop("gpu", None)

    n_points = int(x.shape[0])
    if n_points < n_clusters:
        raise RuntimeError(
            f"need >= {n_clusters} points to fit the codebook, got N={n_points}"
        )

    fitted = faiss.Kmeans(dim, n_clusters, **options)
    fitted.train(x)
    return fitted
class FaissKmeansFitTest(unittest.TestCase):
    """Checks for ``faiss_kmeans_fit``, the shared one-layer FAISS fit."""

    def setUp(self) -> None:
        # Every test here needs faiss; skip (not fail) when it is missing.
        try:
            import faiss  # noqa: F401
        except ImportError:
            self.skipTest("faiss not installed")

    def test_fit_returns_trained_kmeans(self) -> None:
        torch.manual_seed(0)
        n_clusters, dim = 8, 6
        # numpy input (the RQ-VAE call path; no faiss torch-utils needed).
        points = torch.randn(200, dim).numpy()
        fit_options = {"niter": 5, "seed": 1, "verbose": False}
        km = faiss_kmeans_fit(points, dim, n_clusters, fit_options)
        self.assertEqual(tuple(km.centroids.shape), (n_clusters, dim))

    def test_raises_on_too_few_points(self) -> None:
        # 4 rows cannot seed 8 centroids: the guard must raise first.
        too_few = torch.randn(4, 6).numpy()
        with self.assertRaisesRegex(RuntimeError, "need >= 8 points"):
            faiss_kmeans_fit(too_few, 6, 8)
def quantize(self, x: torch.Tensor) -> QuantizeOutput:
    """Assign each row of ``x`` to its nearest codebook row.

    Distances come from ``torch.cdist`` against ``self._codebook``; the
    returned embeddings are fetched through ``self.lookup``.
    """
    nearest = torch.cdist(x, self._codebook).argmin(dim=-1)
    return QuantizeOutput(ids=nearest, embeddings=self.lookup(nearest))
b/tzrec/modules/sid/residual_kmeans_quantizer.py index 11b06951c..e94594be8 100644 --- a/tzrec/modules/sid/residual_kmeans_quantizer.py +++ b/tzrec/modules/sid/residual_kmeans_quantizer.py @@ -17,13 +17,12 @@ from typing import Dict, List, Optional, Tuple, Union -import faiss import faiss.contrib.torch_utils # noqa: F401 (registers torch tensor I/O) import torch from torch import nn from torch.nn import functional as F -from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer +from tzrec.modules.sid.kmeans_quantize import KMeansQuantizeLayer, faiss_kmeans_fit from tzrec.modules.sid.residual_quantizer import ResidualQuantizer from tzrec.utils.logging_util import logger @@ -83,7 +82,6 @@ def _quantize_layer( self, layer_idx: int, residual: torch.Tensor, - temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Nearest-centroid assignment for one layer (delegates to the layer). @@ -93,13 +91,12 @@ def _quantize_layer( Args: layer_idx (int): quantization layer index. residual (Tensor): current residual, shape (B, D). - temperature (float): unused (no soft assignment). Returns: codes (Tensor): cluster indices, shape (B,). quantized (Tensor): selected centroids, shape (B, D). """ - out = self.layers[layer_idx].quantize(residual, temperature) + out = self.layers[layer_idx].quantize(residual) return out.ids, out.embeddings def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -202,10 +199,6 @@ def train_offline( # breaks that invariant, so clone then. x0 = x.clone() if (verbose and self.normalize_residuals) else None - # CPU-only fit (SidRqkmeans refuses CUDA). Drop any stale ``gpu`` kwarg - # so a faiss-gpu build can't target an absent GPU. - kwargs = dict(self.faiss_kmeans_kwargs) - kwargs.pop("gpu", None) if verbose: logger.info( "[ResidualKMeansQuantizer] fitting %d-layer codebook on CPU " @@ -224,15 +217,17 @@ def train_offline( # Fresh Kmeans per layer so each can use its own K (non-uniform # codebooks). 
- kmeans = faiss.Kmeans( - self.embed_dim, self.n_embed_list[layer_idx], **kwargs + km = faiss_kmeans_fit( + x, + self.embed_dim, + self.n_embed_list[layer_idx], + self.faiss_kmeans_kwargs, ) - kmeans.train(x) - centroids = torch.as_tensor(kmeans.centroids, dtype=torch.float32) + centroids = torch.as_tensor(km.centroids, dtype=torch.float32) for start in range(0, N, SEARCH_CHUNK): end = min(start + SEARCH_CHUNK, N) - _, idx = kmeans.index.search(x[start:end], 1) + _, idx = km.index.search(x[start:end], 1) idx = torch.as_tensor(idx).reshape(-1).long() q = centroids[idx] # (chunk, D) out[start:end] += q diff --git a/tzrec/modules/sid/residual_quantizer.py b/tzrec/modules/sid/residual_quantizer.py index 6b80f2e33..2c0c704c7 100644 --- a/tzrec/modules/sid/residual_quantizer.py +++ b/tzrec/modules/sid/residual_quantizer.py @@ -98,17 +98,15 @@ def _quantize_layer( self, layer_idx: int, residual: torch.Tensor, - temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor]: """Assign one layer's codes and look up its quantized vector. Backend primitive behind the residual walk (encode-direction mirror of - :meth:`_lookup_code`). ``temperature`` is used only by the VQ backend. + :meth:`_lookup_code`). Args: layer_idx (int): quantization layer index. residual (Tensor): current residual, shape (B, D). - temperature (float): Gumbel-Softmax temperature (VQ only). Returns: codes (Tensor): per-layer cluster ids, shape (B,). @@ -119,7 +117,6 @@ def _quantize_layer( def _residual_pass( self, input: torch.Tensor, - temperature: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """Shared residual walk: per-layer assign, subtract, accumulate. @@ -129,7 +126,6 @@ def _residual_pass( Args: input (Tensor): input embeddings, shape (B, D). - temperature (float): forwarded to :meth:`_quantize_layer`. Returns: cluster_ids (Tensor): stacked codes, shape (B, n_layers). 
def _quantize_layer(self, layer_idx, residual):
    """Pick, per row, the codebook row with the largest dot product.

    Args:
        layer_idx: index into ``self.books``.
        residual: current residual; detached before scoring.

    Returns:
        (codes, quantized): argmax indices and the selected codebook rows.
    """
    book = self.books[layer_idx]
    scores = residual.detach() @ book.t()
    codes = scores.argmax(dim=-1)
    return codes, book[codes]
@torch.no_grad()
def faiss_residual_kmeans(
    samples: torch.Tensor,
    n_clusters_list: List[int],
    faiss_kmeans_kwargs: Optional[Dict] = None,
) -> List[torch.Tensor]:
    """Residual K-Means warm-start via FAISS, one fit per layer.

    Clusters ``samples``, subtracts each point's assigned centroid, and
    repeats on the residual for the next layer. The fit runs on a private
    host fp32 numpy copy, so ``samples`` is never modified; centroids are
    returned on ``samples.device``.

    The nearest-centroid search and residual update are skipped after the
    last layer: no later fit consumes that residual, so the extra N x K
    search would be wasted work. The returned centroids are unchanged.

    Args:
        samples (Tensor): data points, shape (N, D).
        n_clusters_list (List[int]): per-layer cluster counts. An empty list
            yields an empty result.
        faiss_kmeans_kwargs (Dict|None): extra kwargs for ``faiss.Kmeans``
            (e.g. ``{'niter': 10, 'seed': 123}``).

    Returns:
        List[Tensor]: per-layer centroids ``[(K0, D), ...]`` on samples.device.

    Raises:
        ImportError: if ``faiss`` is not installed.
        RuntimeError: if a layer has fewer points than its cluster count.
        ValueError: if ``samples`` is not 2-D (shape unpacking fails).
    """
    device = samples.device
    _, D = samples.shape
    # ``.copy()`` matters: for a CPU fp32 input ``.numpy()`` shares storage,
    # and the in-place residual update below must not touch the caller's data.
    x = samples.detach().cpu().float().numpy().copy()

    res_centers: List[torch.Tensor] = []
    last_layer = len(n_clusters_list) - 1
    for layer_idx, n_clusters in enumerate(n_clusters_list):
        km = faiss_kmeans_fit(x, D, n_clusters, faiss_kmeans_kwargs)
        # Copy so the returned tensor does not alias the faiss object's buffer.
        centroids = km.centroids.copy()
        res_centers.append(torch.from_numpy(centroids).to(device))
        if layer_idx == last_layer:
            break  # nothing downstream reads the final residual
        _, idx = km.index.search(x, 1)
        x -= centroids[idx.ravel()]
    return res_centers
def __init__(
    self,
    embed_dim: int,
    n_layers: int,
    n_embed: Union[int, List[int]] = 256,
    forward_mode: str = "ste",
    normalize_residuals: bool = False,
    distance_type: str = "l2",
    rotation_trick: bool = False,
    kmeans_init: bool = False,
    use_sinkhorn: bool = True,
    sinkhorn_iters: int = 5,
    sinkhorn_epsilon: float = 10.0,
    gumbel_temperature: float = 1.0,
) -> None:
    """Build the per-layer ``VectorQuantizeLayer`` stack.

    Validates ``forward_mode``, resolves option conflicts for the Gumbel
    mode (Sinkhorn is switched off with a warning; ``rotation_trick`` is
    only warned about), registers the ``initted`` buffer and logs the
    effective configuration.

    Raises:
        ValueError: if ``forward_mode`` is not a key of
            ``_FORWARD_MODE_MAP``.
    """
    super().__init__(embed_dim, n_layers, n_embed, normalize_residuals)
    self.rotation_trick = rotation_trick

    # ``initted`` is True from the start unless a K-Means warm start is
    # requested. Registered before the layers, as in the original.
    self.register_buffer("initted", torch.tensor([not kmeans_init]))

    mode_enum = self._FORWARD_MODE_MAP.get(forward_mode)
    if mode_enum is None:
        raise ValueError(
            f"Unsupported forward_mode '{forward_mode}', "
            f"choose from {list(self._FORWARD_MODE_MAP.keys())}"
        )
    self._forward_mode = mode_enum

    if mode_enum == QuantizeForwardMode.GUMBEL_SOFTMAX:
        if use_sinkhorn:
            logger.warning("gumbel_softmax: disabling incompatible use_sinkhorn.")
            use_sinkhorn = False
        if rotation_trick:
            logger.warning("gumbel_softmax: rotation_trick has no effect; ignoring.")

    # Options common to every layer; only the codebook size varies.
    shared_kwargs = {
        "embed_dim": embed_dim,
        "forward_mode": mode_enum,
        "distance_type": distance_type,
        "use_sinkhorn": use_sinkhorn,
        "sinkhorn_iters": sinkhorn_iters,
        "sinkhorn_epsilon": sinkhorn_epsilon,
        "gumbel_temperature": gumbel_temperature,
    }
    self.layers = nn.ModuleList(
        [
            VectorQuantizeLayer(n_embed=self.n_embed_list[i], **shared_kwargs)
            for i in range(n_layers)
        ]
    )

    logger.info(
        "ResidualVectorQuantizer init: embed_dim=%d, n_layers=%d, "
        "n_embed=%s, forward_mode=%s, normalize_residuals=%s, "
        "distance_type=%s, rotation_trick=%s, kmeans_init=%s, "
        "use_sinkhorn=%s, sinkhorn_iters=%d, sinkhorn_epsilon=%s",
        embed_dim,
        n_layers,
        n_embed,
        forward_mode,
        normalize_residuals,
        distance_type,
        rotation_trick,
        kmeans_init,
        use_sinkhorn,
        sinkhorn_iters,
        sinkhorn_epsilon,
    )
@staticmethod
def _apply_rotation_trick(
    x: torch.Tensor,
    quant: torch.Tensor,
) -> torch.Tensor:
    """Map ``x`` onto ``quant`` with a gradient-carrying rotation + rescale.

    Implements equation 4.2 of https://arxiv.org/abs/2410.06424. The
    transform is built only from detached quantities, so it acts as a
    constant linear map in the backward pass: the forward value equals
    ``quant`` (up to the 1e-8 stabilizer) while the gradient reaching ``x``
    is rotated and rescaled instead of copied as in plain STE.

    Args:
        x (Tensor): original input with gradient, shape (B, D).
        quant (Tensor): quantized output (detached internally),
            shape (B, D).

    Returns:
        Tensor: rotated output, shape (B, D), with gradient flowing
        through ``x`` only.
    """
    eps = 1e-8
    target = quant.detach()
    source = x.detach()

    target_len = torch.linalg.vector_norm(target, dim=-1).unsqueeze(1)
    source_len = torch.linalg.vector_norm(source, dim=-1).unsqueeze(1)
    # Norm ratio: restores the target magnitude after the rotation.
    scale = target_len / (source_len + eps)

    source_dir = source / (source_len + eps)
    target_dir = target / (target_len + eps)
    # Unit bisector of the two directions: the reflection axis.
    bisector = F.normalize(source_dir + target_dir, p=2, dim=1)

    row = x.unsqueeze(1)  # (B, 1, D); the only tensor carrying grad

    # Eq 4.2: Householder reflection
    reflected = row @ bisector.unsqueeze(2) @ bisector.unsqueeze(1)
    carried = row @ source_dir.unsqueeze(2) @ target_dir.unsqueeze(1)
    rotated = row - 2 * reflected + 2 * carried
    return scale * rotated.squeeze(1)
def forward(
    self,
    input: torch.Tensor,
) -> ResidualQuantizerOutput:
    """Run the residual walk and attach the mode-specific encoder gradient.

    In training, ``init_embed_`` is invoked first. STE (and eval) walk a
    DETACHED copy of the input; in training the gradient is then re-attached
    on the aggregate, either by the rotation trick or by the plain
    straight-through estimator. Gumbel-Softmax in training walks the LIVE
    input and returns the aggregate as-is.

    Args:
        input (Tensor): input embeddings, shape (B, D).

    Returns:
        ResidualQuantizerOutput: (cluster_ids, quantized_embeddings,
            latents), where ``latents`` stacks the per-layer cumulative
            sums along dim 1.
    """
    training = self.training
    if training:
        self.init_embed_(input)

    soft_path = training and (
        self._forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX
    )

    if soft_path:
        walk_input = input
    else:
        walk_input = input.detach()
    cluster_ids, aggregated, cumulative = self._residual_pass(walk_input)
    latents = torch.stack(cumulative, dim=1)

    if soft_path or not training:
        # Gumbel already carries grad; eval needs no estimator at all.
        quantized = aggregated
    elif self.rotation_trick:
        quantized = self._apply_rotation_trick(input, aggregated)
    else:
        # Plain STE: forward value is the aggregate, backward is identity.
        quantized = input + (aggregated - input).detach()

    return ResidualQuantizerOutput(
        cluster_ids=cluster_ids,
        quantized_embeddings=quantized,
        latents=latents,
    )
def _lookup_code(self, layer_idx: int, code_idx: torch.Tensor) -> torch.Tensor:
    """Return layer ``layer_idx``'s codebook vectors for ``code_idx``.

    Delegates to that layer's ``lookup`` method.
    """
    layer = self.layers[layer_idx]
    return layer.lookup(code_idx)
class GumbelResidualVQTest(unittest.TestCase):
    """Gradient routing of the Gumbel-Softmax and STE forward modes."""

    @staticmethod
    def _make_rvq(
        n_layers: int, forward_mode: str, use_sinkhorn: bool
    ) -> ResidualVectorQuantizer:
        """Build an 8-dim, 16-code quantizer without the K-Means warm start."""
        return ResidualVectorQuantizer(
            embed_dim=8,
            n_layers=n_layers,
            n_embed=16,
            forward_mode=forward_mode,
            use_sinkhorn=use_sinkhorn,
            kmeans_init=False,
        )

    def test_default_gumbel_config_disables_sinkhorn(self) -> None:
        # use_sinkhorn defaults True; gumbel must auto-disable it (not crash).
        rvq = self._make_rvq(3, "gumbel_softmax", use_sinkhorn=True)
        self.assertTrue(all(not layer.use_sinkhorn for layer in rvq.layers))

    def test_gumbel_grad_flows_via_soft_assignment(self) -> None:
        # With no commitment loss, the reconstruction-path gradient must
        # reach BOTH the encoder input and the codebook through the soft
        # gumbel embedding; otherwise gumbel would silently train like STE.
        torch.manual_seed(0)
        rvq = self._make_rvq(3, "gumbel_softmax", use_sinkhorn=False)
        rvq.train()

        z = torch.randn(32, 8, requires_grad=True)
        rvq(z).quantized_embeddings.sum().backward()

        self.assertIsNotNone(z.grad)
        self.assertGreater(z.grad.abs().sum().item(), 0.0)
        codebook_grad = rvq.layers[0].embedding.weight.grad
        self.assertIsNotNone(codebook_grad)
        self.assertGreater(codebook_grad.abs().sum().item(), 0.0)

    def test_ste_codebook_grad_is_detached_on_recon_path(self) -> None:
        # Contrast: STE detaches the aggregate, so the recon path gives the
        # codebook no gradient (it trains via the commitment loss instead).
        torch.manual_seed(0)
        rvq = self._make_rvq(2, "ste", use_sinkhorn=False)
        rvq.train()

        z = torch.randn(16, 8, requires_grad=True)
        rvq(z).quantized_embeddings.sum().backward()

        codebook_grad = rvq.layers[0].embedding.weight.grad
        self.assertTrue(
            codebook_grad is None or codebook_grad.abs().sum().item() == 0.0
        )

    def test_ste_codebook_grad_flows_via_commitment_latents(self) -> None:
        # The commitment loss consumes ``latents``, so backward through
        # latents MUST reach the codebook. Regression guard: a per-layer STE
        # wrap once detached the codebook from latents, freezing it at init.
        torch.manual_seed(0)
        rvq = self._make_rvq(2, "ste", use_sinkhorn=False)
        rvq.train()

        rvq(torch.randn(16, 8)).latents.sum().backward()

        codebook_grad = rvq.layers[0].embedding.weight.grad
        self.assertIsNotNone(codebook_grad)
        self.assertGreater(codebook_grad.abs().sum().item(), 0.0)
+ self.assertEqual(centers[0].device, samples.device) + + +class ResidualVQBranchTest(unittest.TestCase): + """Coverage for the rotation-trick STE branch and the kmeans-init guard.""" + + def test_rotation_trick_rotates_gradient(self) -> None: + # The rotation trick keeps the STE forward value but ROTATES the input + # gradient. Plain STE also yields a finite non-zero grad, so a regression + # that reverted the Householder branch to ordinary STE would pass a smoke + # test. Pin the distinguishing property: same forward output, different + # input gradient. + def run(rotation_trick: bool): + torch.manual_seed(0) # identical codebook init + rvq = ResidualVectorQuantizer( + embed_dim=8, + n_layers=2, + n_embed=16, + forward_mode="ste", + rotation_trick=rotation_trick, + use_sinkhorn=False, + kmeans_init=False, + ) + rvq.train() + torch.manual_seed(1) # identical input + z = torch.randn(16, 8, requires_grad=True) + out = rvq(z).quantized_embeddings + out.sum().backward() + return out.detach(), z.grad + + out_rot, grad_rot = run(True) + out_ste, grad_ste = run(False) + self.assertTrue(torch.isfinite(grad_rot).all()) + self.assertGreater(grad_rot.abs().sum().item(), 0.0) + # Forward value is identical (the trick only changes the backward). + torch.testing.assert_close(out_rot, out_ste) + # ...but it genuinely rotates the gradient, unlike plain STE. + self.assertFalse(torch.allclose(grad_rot, grad_ste)) + + def test_kmeans_init_too_small_batch_raises(self) -> None: + # kmeans_init needs N >= max(codebook). A too-small first batch must + # raise a clear error (broadcast so all ranks abort together under DDP), + # not hang the non-rank-0 ranks on the centroid broadcast. 
+ rvq = ResidualVectorQuantizer( + embed_dim=4, + n_layers=2, + n_embed=8, + kmeans_init=True, + use_sinkhorn=False, + ) + rvq.train() + with self.assertRaisesRegex(RuntimeError, "fewer rows than the largest"): + rvq(torch.randn(4, 4)) # 4 < max(codebook)=8 + + +class ResidualVectorQuantizerTest(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(0) + self.rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=3, n_embed=16, kmeans_init=False + ) + + def test_is_subclass(self) -> None: + self.assertIsInstance(self.rvq, ResidualQuantizer) + + def test_forward_output(self) -> None: + self.rvq.train() + out = self.rvq(torch.randn(5, 8)) + self.assertIsInstance(out, ResidualQuantizerOutput) + self.assertEqual(out.cluster_ids.shape, (5, 3)) + self.assertEqual(out.quantized_embeddings.shape, (5, 8)) + # latents: per-layer cumulative quantized vectors (B, n_layers, D). + self.assertEqual(out.latents.shape, (5, 3, 8)) + self.assertTrue(torch.isfinite(out.latents).all()) + + def test_forward_get_codes_consistent_eval(self) -> None: + """get_codes (shared base walk) matches forward's ids in eval.""" + self.rvq.eval() + x = torch.randn(6, 8) + fwd_ids = self.rvq(x).cluster_ids + gc_ids = self.rvq.get_codes(x) + self.assertFalse(gc_ids.requires_grad) + torch.testing.assert_close(gc_ids, fwd_ids) + + def test_faiss_kmeans_init_seeds_codebook(self) -> None: + try: + import faiss # noqa: F401 + except ImportError: + self.skipTest("faiss not installed") + torch.manual_seed(0) + rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True + ) + self.assertFalse(bool(rvq.initted.item())) + rvq.train() + # First training forward triggers the FAISS warm-start. 
+ rvq(torch.randn(512, 8)) + self.assertTrue(bool(rvq.initted.item())) + for layer in rvq.layers: + self.assertTrue(torch.isfinite(layer.embedding.weight).all()) + self.assertGreater(layer.embedding.weight.abs().sum().item(), 0.0) + + +# --- Multi-process test for ResidualVectorQuantizer FAISS kmeans-init. --- +# Validates the DDP path of ``init_embed_``: the codebook is fit on rank 0 only +# and broadcast, so every rank ends with a bit-identical warm start. (The +# previous behavior averaged permutation-misaligned per-rank centroids, which the +# review flagged as a near-random init.) Uses NCCL on GPU when >=2 devices are +# available, else gloo/CPU. + +WORLD_SIZE = 2 + + +def _init(rank: int, world_size: int, port: int) -> torch.device: + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + use_cuda = torch.cuda.is_available() and torch.cuda.device_count() >= world_size + if use_cuda: + torch.cuda.set_device(rank) + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + return torch.device(f"cuda:{rank}") + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + return torch.device("cpu") + + +def _init_embed_worker(rank: int, world_size: int, port: int) -> None: + device = _init(rank, world_size, port) + # Rank-distinct data: a per-rank average/init would diverge; only a + # broadcast-from-rank0 init yields identical codebooks. 
+ torch.manual_seed(rank) + rvq = ResidualVectorQuantizer( + embed_dim=8, n_layers=2, n_embed=16, kmeans_init=True + ).to(device) + rvq.train() + rvq(torch.randn(512, 8, device=device)) # first forward triggers init_embed_ + assert bool(rvq.initted.item()), f"rank{rank}: not initialized" + + for layer in rvq.layers: + w = layer.embedding.weight.detach().clone() + wmin, wmax = w.clone(), w.clone() + dist.all_reduce(wmin, op=dist.ReduceOp.MIN) + dist.all_reduce(wmax, op=dist.ReduceOp.MAX) + assert torch.allclose(wmin, wmax), ( + f"rank{rank}: codebook differs across ranks (init not broadcast)" + ) + dist.destroy_process_group() + + +class ResidualVectorQuantizerDistTest(unittest.TestCase): + """2-rank test for the FAISS kmeans-init broadcast.""" + + def test_init_embed_broadcast(self) -> None: + port = misc_util.get_free_port() + ctx = mp.get_context("spawn") + procs = [] + for rank in range(WORLD_SIZE): + p = ctx.Process(target=_init_embed_worker, args=(rank, WORLD_SIZE, port)) + p.start() + procs.append(p) + for i, p in enumerate(procs): + p.join() + if p.exitcode != 0: + raise RuntimeError(f"worker-{i} failed (exitcode={p.exitcode}).") + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/modules/sid/types.py b/tzrec/modules/sid/types.py index 2f0cf3c60..07c31fc9a 100644 --- a/tzrec/modules/sid/types.py +++ b/tzrec/modules/sid/types.py @@ -9,13 +9,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Data types for SID generation: output tuples shared across quantizers.""" +"""Data types for SID generation: enums and output tuples shared across quantizers.""" +from enum import Enum from typing import NamedTuple import torch +class QuantizeForwardMode(Enum): + """Forward mode for vector quantization (RQ-VAE backend). + + Attributes: + GUMBEL_SOFTMAX: use Gumbel-Softmax reparameterization. + STE: use Straight-Through Estimator. 
+ """ + + GUMBEL_SOFTMAX = 1 + STE = 2 + + class QuantizeOutput(NamedTuple): """One quantize layer's output. @@ -26,3 +39,22 @@ class QuantizeOutput(NamedTuple): embeddings: torch.Tensor ids: torch.Tensor + + +class ResidualQuantizerOutput(NamedTuple): + """Output of the residual quantization module (RQ-VAE backend). + + The per-layer cumulative quantized vectors are exposed as ``latents`` so the + model-side commitment loss + (:class:`~tzrec.loss.sid_commitment_loss.SidCommitmentLoss`) can consume them. + + Attributes: + cluster_ids (Tensor): codebook indices per layer, shape (B, n_layers). + quantized_embeddings (Tensor): sum of quantized embeddings, shape (B, D). + latents (Tensor): per-layer cumulative quantized vectors, shape + (B, n_layers, D) (``latents[:, i]`` is the sum after layer ``i``). + """ + + cluster_ids: torch.Tensor + quantized_embeddings: torch.Tensor + latents: torch.Tensor diff --git a/tzrec/modules/sid/vector_quantize.py b/tzrec/modules/sid/vector_quantize.py new file mode 100644 index 000000000..4c41b3fbd --- /dev/null +++ b/tzrec/modules/sid/vector_quantize.py @@ -0,0 +1,230 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Single codebook vector quantization with Sinkhorn uniform assignment."""

import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F

from tzrec.modules.sid.quantize_layer import QuantizeLayer
from tzrec.modules.sid.types import (
    QuantizeForwardMode,
    QuantizeOutput,
)

# Distance metrics understood by ``VectorQuantizeLayer._compute_distances``.
_DISTANCE_TYPES = ("l2", "cosine")


def _is_distributed() -> bool:
    """Return True when a default process group exists for collectives.

    ``dist.is_available()`` is checked first so the probe is safe on PyTorch
    builds compiled without distributed support.
    """
    return dist.is_available() and dist.is_initialized()


@torch.no_grad()
def _sinkhorn(
    cost: torch.Tensor,
    n_iters: int = 5,
    epsilon: float = 10.0,
) -> torch.Tensor:
    """Sinkhorn-Knopp algorithm for optimal-transport based uniform assignment.

    Transforms a distance matrix into a soft assignment matrix via exponential
    kernel and alternating row-column normalization, approximating a doubly
    stochastic matrix to ensure uniform codebook utilization. The total mass
    and the per-code row sums are all-reduced across ranks when a process
    group is initialized.

    Args:
        cost (Tensor): distance matrix, shape (B, K) where K is codebook size.
            IMPORTANT: must be z-score normalized and shifted to non-negative
            before calling this function to avoid numerical overflow.
        n_iters (int): number of Sinkhorn iterations. Default: 5.
        epsilon (float): sharpness parameter for exp(-cost * epsilon).
            Larger values produce sharper assignments. Default: 10.0.

    Returns:
        Tensor: assignment matrix, shape (B, K).
            Use Q.argmax(dim=-1) externally to get codebook indices.
    """
    distributed = _is_distributed()

    # Work on the transposed (K, B) kernel: rows are codes, columns are samples.
    Q = torch.exp(-cost * epsilon).t()
    K = Q.size(0)
    B = Q.size(1)
    if distributed:
        # NOTE(review): this equals the true global batch size only when every
        # rank holds the same number of rows.
        B *= dist.get_world_size()

    # Normalize the (global) total mass to ~1 before iterating.
    sum_Q = torch.sum(Q)
    if distributed:
        dist.all_reduce(sum_Q)
    Q /= sum_Q + 1e-8

    for _ in range(n_iters):
        # Row step: each code receives 1/K of the (global) mass.
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        if distributed:
            dist.all_reduce(sum_of_rows)
        Q /= sum_of_rows + 1e-8
        Q /= K

        # Column step: each sample carries 1/B of the mass.
        Q /= torch.sum(Q, dim=0, keepdim=True) + 1e-8
        Q /= B

    # Rescale so every sample's assignment column sums to ~1.
    Q *= B
    return Q.t()


class VectorQuantizeLayer(QuantizeLayer):
    """Single codebook vector quantization layer (RQ-VAE backend).

    A gradient-trained ``nn.Embedding`` codebook (the VQ ``QuantizeLayer``),
    sibling of the K-Means backend's ``KMeansQuantizeLayer``. Maps inputs to a
    codebook entry via :meth:`quantize`. Loss-free: the commitment loss is
    computed model-side by
    :class:`tzrec.loss.sid_commitment_loss.SidCommitmentLoss`
    over the quantizer's per-layer ``latents``. Sinkhorn optimal-transport
    assignment optionally balances codebook usage in training.

    Args:
        embed_dim (int): dimension of each codebook embedding.
        n_embed (int): number of codebook entries.
        forward_mode (QuantizeForwardMode): quantization forward mode,
            either GUMBEL_SOFTMAX or STE. Default: STE.
        distance_type (str): distance metric, 'l2' or 'cosine'.
            Default: 'l2'.
        use_sinkhorn (bool): whether to use Sinkhorn uniform assignment
            during training. Default: True.
        sinkhorn_iters (int): number of Sinkhorn iterations. Default: 5.
        sinkhorn_epsilon (float): Sinkhorn sharpness parameter for
            exp(-cost * epsilon). Default: 10.0.
        gumbel_temperature (float): Gumbel-Softmax temperature (tau), used only
            in GUMBEL_SOFTMAX training. Default: 1.0.

    Raises:
        AssertionError: if ``use_sinkhorn`` is combined with GUMBEL_SOFTMAX.
        ValueError: if ``distance_type`` is unsupported, if
            ``sinkhorn_epsilon <= 0`` while Sinkhorn is enabled, or if
            ``gumbel_temperature <= 0`` in GUMBEL_SOFTMAX mode.
    """

    def __init__(
        self,
        embed_dim: int,
        n_embed: int,
        forward_mode: QuantizeForwardMode = QuantizeForwardMode.STE,
        distance_type: str = "l2",
        use_sinkhorn: bool = True,
        sinkhorn_iters: int = 5,
        sinkhorn_epsilon: float = 10.0,
        gumbel_temperature: float = 1.0,
    ) -> None:
        super().__init__(n_embed=n_embed, embed_dim=embed_dim)
        is_gumbel = forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX
        if use_sinkhorn and is_gumbel:
            # Raised explicitly instead of via ``assert`` so the guard is not
            # stripped under ``python -O``; the AssertionError type and message
            # are kept for callers/tests that already match on them.
            raise AssertionError(
                "use_sinkhorn=True is incompatible with forward_mode=GUMBEL_SOFTMAX: "
                "Sinkhorn drives `ids` (balanced assignment) while Gumbel drives "
                "`emb` (nearest code), so the returned id and embedding diverge. "
                "Use STE with Sinkhorn, or Gumbel-Softmax without Sinkhorn."
            )
        # Fail at construction rather than on the first forward pass.
        if distance_type not in _DISTANCE_TYPES:
            raise ValueError(
                f"Unsupported distance_type '{distance_type}', "
                f"choose from ('l2', 'cosine')"
            )
        # epsilon sharpens exp(-cost * epsilon); <= 0 flips the kernel and the
        # (large, shifted) cost overflows to +Inf -> NaN assignments.
        if use_sinkhorn and sinkhorn_epsilon <= 0:
            raise ValueError(f"sinkhorn_epsilon must be > 0, got {sinkhorn_epsilon}")
        # tau divides the Gumbel-perturbed logits: 0 yields inf/NaN and a
        # negative value inverts the preference toward the farthest code.
        if is_gumbel and gumbel_temperature <= 0:
            raise ValueError(
                f"gumbel_temperature must be > 0, got {gumbel_temperature}"
            )
        self.forward_mode = forward_mode
        self.distance_type = distance_type
        self.use_sinkhorn = use_sinkhorn
        self.sinkhorn_iters = sinkhorn_iters
        self.sinkhorn_epsilon = sinkhorn_epsilon
        self.gumbel_temperature = gumbel_temperature

        self.embedding = nn.Embedding(n_embed, embed_dim)
        nn.init.kaiming_uniform_(self.embedding.weight)

    def _compute_distances(self, x: torch.Tensor) -> torch.Tensor:
        """Compute L2/cosine distances between inputs and codebook entries.

        Not ``no_grad``: Gumbel calls this directly for the encoder gradient;
        the STE/Sinkhorn path calls it inside ``no_grad`` in
        :meth:`_find_nearest_embedding`.

        Args:
            x (Tensor): input vectors, shape (B, D).

        Returns:
            Tensor: pairwise distances, shape (B, n_embed). Squared Euclidean
                distance for 'l2'; negated cosine similarity for 'cosine'.

        Raises:
            ValueError: if ``self.distance_type`` was changed to an
                unsupported value after construction.
        """
        codebook = self.embedding.weight

        if self.distance_type == "l2":
            distances = torch.cdist(x, codebook, p=2).pow(2)
        elif self.distance_type == "cosine":
            x_norm = F.normalize(x, p=2, dim=1)
            codebook_norm = F.normalize(codebook, p=2, dim=1)
            # Negated so that "smaller is closer" holds for both metrics.
            distances = -torch.matmul(x_norm, codebook_norm.t())
        else:
            raise ValueError(
                f"Unsupported distance_type '{self.distance_type}', "
                f"choose from ('l2', 'cosine')"
            )
        return distances

    @torch.no_grad()
    def _find_nearest_embedding(self, x: torch.Tensor) -> torch.Tensor:
        """Find the nearest codebook id for each input vector.

        During training with use_sinkhorn=True, applies z-score
        normalization + non-negative shift before Sinkhorn assignment.
        Otherwise falls back to argmin.

        Args:
            x (Tensor): input vectors, shape (B, D).

        Returns:
            Tensor: codebook indices, shape (B,).
        """
        distances = self._compute_distances(x)

        if self.training and self.use_sinkhorn:
            # Sinkhorn requires non-negative cost; z-score then shift.
            std, mean = torch.std_mean(distances, unbiased=False)
            distances = (distances - mean) / std.add(1e-12)
            distances = distances - distances.min()

            Q = _sinkhorn(
                distances,
                n_iters=self.sinkhorn_iters,
                epsilon=self.sinkhorn_epsilon,
            )
            ids = Q.argmax(dim=-1)
        else:
            ids = distances.argmin(dim=-1)

        return ids

    def quantize(self, x: torch.Tensor) -> QuantizeOutput:
        """Assign ``x`` to the codebook (the :class:`QuantizeLayer` interface).

        Commitment loss is computed by the caller; device follows ``x``, so this
        runs on CPU or GPU unchanged. The Gumbel temperature is the
        ``gumbel_temperature`` init parameter.

        Args:
            x (Tensor): input vectors, shape (B, D).

        Returns:
            QuantizeOutput: named tuple of (embeddings, ids).
        """
        if self.training and self.forward_mode == QuantizeForwardMode.GUMBEL_SOFTMAX:
            # Closer codes get larger logits; ``hard=True`` samples a one-hot
            # row, so the id below indexes exactly the vector used in ``emb``.
            logits = -self._compute_distances(x)
            weights = F.gumbel_softmax(
                logits, tau=self.gumbel_temperature, hard=True, dim=-1
            )
            emb = weights @ self.embedding.weight
            ids = weights.argmax(dim=-1)
            return QuantizeOutput(embeddings=emb, ids=ids)

        # Return the RAW codebook vector (no per-layer STE wrap): the aggregate
        # STE in ResidualVectorQuantizer.forward routes the encoder gradient,
        # while a wrap here would detach the codebook from ``latents`` and freeze
        # it at init.
        ids = self._find_nearest_embedding(x)
        return QuantizeOutput(embeddings=self.embedding(ids), ids=ids)

    def get_codebook_embeddings(self) -> torch.Tensor:
        """Return the codebook table, shape (n_embed, embed_dim)."""
        return self.embedding.weight
+ layer.embedding.weight.data.copy_(torch.tensor([[0.0, 0.0], [0.0, 1.0]])) + x = torch.tensor([[0.0, 0.0], [1.0, 0.0]]) + d = layer._compute_distances(x) + self.assertEqual(d.shape, (2, 2)) + # row0: dist² to (0,0)=0, to (0,1)=1; row1: to (0,0)=1, to (0,1)=2 + torch.testing.assert_close(d, torch.tensor([[0.0, 1.0], [1.0, 2.0]])) + + @parameterized.expand( + [ + ("ste_l2", QuantizeForwardMode.STE, "l2", True), + ("ste_cosine", QuantizeForwardMode.STE, "cosine", True), + ("ste_no_sinkhorn", QuantizeForwardMode.STE, "l2", False), + # Gumbel must run without Sinkhorn (the combo is asserted against). + ("gumbel_l2", QuantizeForwardMode.GUMBEL_SOFTMAX, "l2", False), + ] + ) + def test_train_forward(self, _name, mode, distance_type, use_sinkhorn) -> None: + torch.manual_seed(0) + vq = VectorQuantizeLayer( + embed_dim=8, + n_embed=16, + forward_mode=mode, + distance_type=distance_type, + use_sinkhorn=use_sinkhorn, + ) + vq.train() + x = torch.randn(5, 8, requires_grad=True) + out = vq.quantize(x) + self.assertEqual(out.embeddings.shape, (5, 8)) + self.assertEqual(out.ids.shape, (5,)) + self.assertTrue((out.ids >= 0).all() and (out.ids < 16).all()) + self.assertTrue(torch.isfinite(out.embeddings).all()) + + def test_sinkhorn_balances_assignment(self) -> None: + """Sinkhorn spreads clustered points across codes; argmin collapses them. + + Functional check (not just shape/finiteness): feed points clustered at + one anchor — argmin sends all to that code, while Sinkhorn's uniform + assignment must use more than one code. 
+ """ + torch.manual_seed(0) + vq = VectorQuantizeLayer( + embed_dim=2, n_embed=4, use_sinkhorn=True, sinkhorn_iters=10 + ) + vq.train() + with torch.no_grad(): + vq.embedding.weight.copy_( + torch.tensor([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0]]) + ) + x = torch.randn(16, 2) * 0.1 # all clustered at anchor 0 + sinkhorn_ids = vq.quantize(x).ids + vq.use_sinkhorn = False + argmin_ids = vq.quantize(x).ids + self.assertEqual(argmin_ids.unique().numel(), 1) + self.assertGreater(sinkhorn_ids.unique().numel(), 1) + + def test_sinkhorn_gumbel_combo_rejected(self) -> None: + """Sinkhorn + Gumbel would desync `ids` and `emb`; constructor rejects it.""" + with self.assertRaisesRegex(AssertionError, "GUMBEL_SOFTMAX"): + VectorQuantizeLayer( + embed_dim=8, + n_embed=16, + forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, + use_sinkhorn=True, + ) + + def test_sinkhorn_epsilon_must_be_positive(self) -> None: + """Reject a non-positive sinkhorn_epsilon (it overflows exp(-cost*eps)).""" + with self.assertRaisesRegex(ValueError, "sinkhorn_epsilon"): + VectorQuantizeLayer( + embed_dim=8, n_embed=16, use_sinkhorn=True, sinkhorn_epsilon=0.0 + ) + + def test_train_forward_backward_reaches_codebook(self) -> None: + torch.manual_seed(0) + vq = VectorQuantizeLayer(embed_dim=8, n_embed=16, use_sinkhorn=False) + vq.train() + x = torch.randn(5, 8, requires_grad=True) + out = vq.quantize(x) + out.embeddings.sum().backward() + # The layer returns the raw codebook vector, so gradient reaches the + # codebook (the encoder STE is applied on the aggregate by the RVQ). + self.assertIsNotNone(vq.embedding.weight.grad) + self.assertTrue(torch.isfinite(vq.embedding.weight.grad).all()) + + def test_eval_forward_is_plain_lookup(self) -> None: + torch.manual_seed(0) + vq = VectorQuantizeLayer(embed_dim=4, n_embed=8) + vq.eval() + x = torch.randn(3, 4) + out = vq.quantize(x) + # In eval, emb == embedding(ids) exactly. 
+ torch.testing.assert_close(out.embeddings, vq.embedding(out.ids)) + + def test_gumbel_train_ids_match_embedding(self) -> None: + # In gumbel training the saved code must index the codebook vector + # actually used (the hard sample), so emb forward == embedding(ids). + # (Under the old code ids came from argmin and could disagree with the + # gumbel-sampled embedding.) + torch.manual_seed(0) + vq = VectorQuantizeLayer( + embed_dim=8, + n_embed=16, + forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, + use_sinkhorn=False, + ) + vq.train() + out = vq.quantize(torch.randn(5, 8)) + torch.testing.assert_close(out.embeddings, vq.embedding(out.ids)) + + def test_gumbel_train_distances_are_differentiable(self) -> None: + # Gumbel needs the assignment differentiable: grad must reach the input. + torch.manual_seed(0) + vq = VectorQuantizeLayer( + embed_dim=8, + n_embed=16, + forward_mode=QuantizeForwardMode.GUMBEL_SOFTMAX, + use_sinkhorn=False, + ) + vq.train() + x = torch.randn(5, 8, requires_grad=True) + vq.quantize(x).embeddings.sum().backward() + self.assertIsNotNone(x.grad) + self.assertTrue(torch.isfinite(x.grad).all()) + self.assertGreater(x.grad.abs().sum().item(), 0.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/protos/loss.proto b/tzrec/protos/loss.proto index 6468cc60e..97d62c685 100644 --- a/tzrec/protos/loss.proto +++ b/tzrec/protos/loss.proto @@ -9,6 +9,36 @@ message LossConfig { JRCLoss jrc_loss = 4; BinaryFocalLoss binary_focal_loss = 5; } + // Losses for semantic-ID (SID) generation models (SidRqvae). A SID model + // lists one LossConfig per term it trains on (a reconstruction loss, the + // commitment loss, and optionally the pair contrastive loss). + oneof sid_loss { + SidReconLoss recon_loss = 6; + SidCommitmentLoss commitment_loss = 7; + SidContrastiveLoss contrastive_loss = 8; + } +} + +// RQ-VAE reconstruction loss (input vs. decoder output). 
+message SidReconLoss { + // Distance for the reconstruction term: "l2" (mse), "l1" or "cos". + optional string recon_type = 1 [default = "l2"]; +} + +// RQ-VAE commitment loss between the encoder output and the per-layer +// cumulative quantized vectors. +message SidCommitmentLoss { + // Commitment loss weights [w1, w2] (encoder-toward-quant, quant-toward- + // encoder). Defaults to [1.0, 0.5] when unset. + repeated float latent_weight = 1; + // Distance used for the commitment term: "l2", "l1" or "cos". + optional string commitment_type = 2 [default = "l2"]; +} + +// Enables the pair contrastive (masked InfoNCE) objective for a SID model. The +// paired-feature wiring lives on the model (SidRqvae.contrastive_config); this +// just turns the objective on (any loss hyperparameters would go here). +message SidContrastiveLoss { } message BinaryCrossEntropy { diff --git a/tzrec/protos/model.proto b/tzrec/protos/model.proto index 58b719a7a..d2c34ae0f 100644 --- a/tzrec/protos/model.proto +++ b/tzrec/protos/model.proto @@ -79,7 +79,7 @@ message ModelConfig { RocketLaunching rocket_launching = 500; // SID generation models - // (600 is reserved for SidRqvae, arriving in the follow-up PR) + SidRqvae sid_rqvae = 600; SidRqkmeans sid_rqkmeans = 601; } diff --git a/tzrec/protos/models/sid_model.proto b/tzrec/protos/models/sid_model.proto index e51462efa..23efdce93 100644 --- a/tzrec/protos/models/sid_model.proto +++ b/tzrec/protos/models/sid_model.proto @@ -15,10 +15,81 @@ message FaissKmeansConfig { optional bool verbose = 7; } +message SinkhornConfig { + // Number of Sinkhorn iterations. + optional uint32 iters = 1 [default = 5]; + // Sinkhorn sharpness parameter (epsilon). + optional float epsilon = 2 [default = 10.0]; + // Whether Sinkhorn uniform assignment is enabled. Default: true. + // Set ``enabled: false`` to turn Sinkhorn off while still providing + // the sub-config (e.g. to override iters/epsilon for an A/B comparison). 
+ optional bool enabled = 3 [default = true]; +} + +// Pair contrastive (dual-encoder) wiring for SidRqvae: which paired feature +// group to encode and which group flags the contrastive-pair rows. This is +// model structure / input contract (declared on the model), not loss config. +message ContrastiveConfig { + // Name of the second (paired) embedding FEATURE GROUP (built by the same + // EmbeddingGroup as the main input; same total dim as `feature_group`). + required string pair_feature_group = 1; + // Name of the per-row pair-flag FEATURE GROUP (a single raw feature, + // dim 1; >0.5 = contrastive pair). + required string pair_flag_feature_group = 2; +} + +message SidRqvae { + // === Network structure === + // Quantization latent dimension (encoder output / codebook dim). + optional uint32 embed_dim = 2 [default = 64]; + // Encoder hidden layer sizes, e.g. [256, 128]. + // Defaults to [input_dim // 2] when unset. + repeated uint32 hidden_dims = 3; + // Per-layer codebook size, e.g. [256, 256, 256]. + // List length is the number of residual quantization layers; + // non-uniform codebooks such as [512, 256, 128] are supported. + repeated uint32 codebook = 5; + + // === Quantization strategy === + // VQ forward mode: "ste" or "gumbel_softmax". + optional string forward_mode = 6 [default = "ste"]; + // L2-normalize residuals before each quantization layer. + optional bool normalize_residuals = 7 [default = false]; + // Distance metric: "l2" or "cosine". + optional string distance_type = 9 [default = "l2"]; + // STE rotation trick. + optional bool rotation_trick = 12 [default = false]; + // KMeans codebook initialization on first training forward. Default false. + // Best-effort warm-start only: it seeds the codebook from a SINGLE batch + // (the encoder is still near-random at step 1), and gradient training + + // commitment loss refine it afterward. 
Requirements/caveats: + // * batch_size >= max(codebook) (FAISS requires N >= K) — so it is + // unusable with large codebooks at typical batch sizes (e.g. an 8192 + // codebook needs a >= 8192-row first batch); under DDP a too-small + // rank-0 batch now raises on all ranks instead of hanging. + // * one batch is statistically thin (few points per centroid); for a + // data-rich fit use the SidRqkmeans model (reservoir-sampled) instead. + // Opt-in for these reasons. + optional bool kmeans_init = 13 [default = false]; + + // === Optional sub-module configs === + // Sinkhorn uniform assignment. Default behavior when this block is + // omitted: enabled with iters=5, epsilon=10.0. Include the sub-config + // to override params; set ``enabled: false`` inside it to disable. + optional SinkhornConfig sinkhorn_config = 15; + // Pair contrastive (dual-encoder) structure: when set, the model encodes a + // second (paired) feature group and runs the contrastive path. This declares + // the model's input contract + topology; the contrastive OBJECTIVE is enabled + // separately by a `contrastive_loss` entry in ModelConfig.losses (both must + // be set together). + optional ContrastiveConfig contrastive_config = 16; + + // Reconstruction, commitment and (optional) contrastive losses are configured + // via ModelConfig.losses (the LossConfig ``sid_loss`` oneof); only the + // contrastive feature wiring above lives on this message. +} + message SidRqkmeans { - // Input embedding dimension (K-Means runs directly on raw embeddings, - // no encoder). - optional uint32 input_dim = 1 [default = 512]; // Per-layer cluster counts, e.g. [256, 256, 256]. // List length is the number of residual quantization layers. Entries // may differ per layer (non-uniform codebooks such as [256, 512, 1024] @@ -36,7 +107,4 @@ message SidRqkmeans { // what FAISS subsamples to internally (default 256), so no training points // are wasted. 
optional uint32 train_sample_size = 6 [default = 0]; - - // Name of the item embedding feature inside the input Batch. - optional string embedding_feature_name = 40 [default = "item_emb"]; } diff --git a/tzrec/tests/configs/sid_rqkmeans_mock.config b/tzrec/tests/configs/sid_rqkmeans_mock.config index 0e6dec907..cd95dac8d 100644 --- a/tzrec/tests/configs/sid_rqkmeans_mock.config +++ b/tzrec/tests/configs/sid_rqkmeans_mock.config @@ -42,12 +42,10 @@ model_config { group_type: DEEP } sid_rqkmeans { - input_dim: 16 codebook: 16 codebook: 16 codebook: 16 normalize_residuals: false - embedding_feature_name: "item_emb" faiss_kmeans_kwargs { niter: 5 seed: 42 diff --git a/tzrec/tests/configs/sid_rqvae_contrastive_mock.config b/tzrec/tests/configs/sid_rqvae_contrastive_mock.config new file mode 100644 index 000000000..0fbe9a4cc --- /dev/null +++ b/tzrec/tests/configs/sid_rqvae_contrastive_mock.config @@ -0,0 +1,97 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/sid_rqvae_contrastive_mock" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 2 + save_checkpoints_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 256 + dataset_type: ParquetDataset + fg_mode: FG_DAG + num_workers: 2 +} +feature_configs { + raw_feature { + feature_name: "item_emb" + expression: "item:embedding" + value_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "pair_emb" + expression: "item:pair_embedding" + value_dim: 16 + } +} +feature_configs { + raw_feature { + feature_name: "is_pair" + expression: "item:is_pair" + value_dim: 1 + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "item_emb" + group_type: DEEP + } + feature_groups { + group_name: "pair" + feature_names: "pair_emb" + group_type: DEEP + } + feature_groups { + group_name: "pair_flag" + feature_names: 
"is_pair" + group_type: DEEP + } + sid_rqvae { + embed_dim: 8 + hidden_dims: 16 + codebook: 16 + codebook: 16 + codebook: 16 + forward_mode: "ste" + kmeans_init: false + # "pair" shares the main encoder so must match its dim; "pair_flag" flags + # rows (>0.5 = pair). + contrastive_config { + pair_feature_group: "pair" + pair_flag_feature_group: "pair_flag" + } + } + losses { + recon_loss { + recon_type: "l2" + } + } + losses { + commitment_loss { + latent_weight: 1.0 + latent_weight: 0.5 + } + } + losses { + contrastive_loss { + } + } +} diff --git a/tzrec/tests/configs/sid_rqvae_mock.config b/tzrec/tests/configs/sid_rqvae_mock.config new file mode 100644 index 000000000..6f4691136 --- /dev/null +++ b/tzrec/tests/configs/sid_rqvae_mock.config @@ -0,0 +1,63 @@ +train_input_path: "" +eval_input_path: "" +model_dir: "experiments/sid_rqvae_mock" +train_config { + sparse_optimizer { + adagrad_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 2 + save_checkpoints_epochs: 1 +} +eval_config { +} +data_config { + batch_size: 256 + dataset_type: ParquetDataset + fg_mode: FG_DAG + num_workers: 2 +} +feature_configs { + raw_feature { + feature_name: "item_emb" + expression: "item:embedding" + value_dim: 16 + } +} +model_config { + feature_groups { + group_name: "deep" + feature_names: "item_emb" + group_type: DEEP + } + sid_rqvae { + embed_dim: 8 + hidden_dims: 16 + codebook: 16 + codebook: 16 + codebook: 16 + forward_mode: "ste" + kmeans_init: false + } + losses { + recon_loss { + recon_type: "l2" + } + } + losses { + commitment_loss { + latent_weight: 1.0 + latent_weight: 0.5 + } + } +} diff --git a/tzrec/tests/sid_integration_test.py b/tzrec/tests/sid_integration_test.py index 53f24a1d3..4107c4b0f 100644 --- a/tzrec/tests/sid_integration_test.py +++ b/tzrec/tests/sid_integration_test.py @@ -44,11 +44,17 @@ def tearDown(self): if self.success and 
os.path.exists(self.test_dir): shutil.rmtree(self.test_dir) - def _prepare_config(self, num_rows: int, dim: int) -> str: + def _prepare_config( + self, + num_rows: int, + dim: int, + base_config: str = "tzrec/tests/configs/sid_rqkmeans_mock.config", + ) -> str: """Write an embedding parquet + a SID config pointed at it. - Single dense ``embedding`` column, no labels — SID reads the item - embedding straight from the batch. Returns the saved config path. + Single dense ``embedding`` column, no labels — the config's FG maps it + to the ``item_emb`` feature, which its ``deep`` feature_group feeds to + the model's EmbeddingGroup. Returns the saved config path. """ data_dir = os.path.join(self.test_dir, "sid_data") os.makedirs(data_dir, exist_ok=True) @@ -61,9 +67,7 @@ def _prepare_config(self, num_rows: int, dim: int) -> str: # train_input_path set -> load_config_for_test uses it as-is (the # FG_DAG auto-mock path is match-model-specific; SID is single-table). - config = config_util.load_pipeline_config( - "tzrec/tests/configs/sid_rqkmeans_mock.config" - ) + config = config_util.load_pipeline_config(base_config) config.train_input_path = data_glob config.eval_input_path = data_glob config_path = os.path.join(self.test_dir, "sid.config") @@ -117,6 +121,50 @@ def test_sid_rqkmeans_train_eval(self): self.assertLess(metrics["rel_loss"], 1.0) self.assertGreater(metrics["unique_sid_ratio"], 0.0) + @unittest.skipIf( + torch.cuda.is_available(), + "the SID integration tests run on the CPU CI job; forcing CPU on a " + "CUDA-built (GPU) image is unreliable.", + ) + def test_sid_rqvae_train_eval(self): + """End-to-end SidRqvae train -> checkpoint -> eval (gradient-trained). + + Exercises the full RQ-VAE pipeline (encode -> quantize -> decode, + commitment + reconstruction loss, gradient training, checkpoint, eval + metrics). On random data the model need not beat the all-zero baseline, + so only finiteness + nonzero SID variety are asserted (not rel_loss<1). 
+ """ + config_path = self._prepare_config( + num_rows=2048, + dim=16, + base_config="tzrec/tests/configs/sid_rqvae_mock.config", + ) + self.success = utils.test_train_eval(config_path, self.test_dir) + # train_eval writes train_eval_result_v2.txt; a standalone eval pass + # (like the rqkmeans test) writes eval_result.txt from the checkpoint. + if self.success: + self.success = utils.test_eval( + os.path.join(self.test_dir, "pipeline.config"), self.test_dir + ) + self.assertTrue(self.success) + # save_checkpoints_epochs=1 persists a checkpoint during training. + self.assertTrue( + glob.glob(os.path.join(self.test_dir, "train", "model.ckpt-*")), + "no checkpoint persisted", + ) + result_path = os.path.join(self.test_dir, "train", "eval_result.txt") + self.assertTrue(os.path.exists(result_path), "no eval_result.txt produced") + with open(result_path) as f: + lines = [ln for ln in f.read().splitlines() if ln.strip()] + self.assertTrue(lines, "eval_result.txt is empty") + metrics = json.loads(lines[-1]) + for key in ("mse", "rel_loss", "unique_sid_ratio"): + self.assertIn(key, metrics) + self.assertTrue( + math.isfinite(metrics[key]), f"{key} not finite: {metrics[key]}" + ) + self.assertGreater(metrics["unique_sid_ratio"], 0.0) + if __name__ == "__main__": unittest.main() diff --git a/tzrec/version.py b/tzrec/version.py index c9acfa881..489ed9a3a 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.2.21" +__version__ = "1.2.22"