Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
139 commits
Select commit Hold shift + click to select a range
89fad00
[feat] SID: port RQ-VAE and RQ-KMeans semantic-ID generation models
WhiteSwan1 Jun 1, 2026
1f494e6
[refactor] SID proto: use repeated numeric fields, drop _sid_helpers
WhiteSwan1 Jun 1, 2026
ebdbd34
[refactor] SID modules: introduce ResidualQuantizer abstract base + r…
WhiteSwan1 Jun 1, 2026
5c6f67a
[refactor] SID models: add BaseSidModel parent for shared init/metric
WhiteSwan1 Jun 1, 2026
d1cf153
[chore] SID: bump copyright year 2024 -> 2026 on newly added files
WhiteSwan1 Jun 1, 2026
b2bded0
[refactor] SID: drop the redundant RQKMeans wrapper
WhiteSwan1 Jun 1, 2026
c6efa87
[refactor] SID: fold RQVAE module into SidRqvae model
WhiteSwan1 Jun 1, 2026
0bd7a44
[refactor] SID: address review nits (std_mean, test config, .cpu() doc)
WhiteSwan1 Jun 1, 2026
e867af2
[test] SID: add module-level unit tests for sid_generation
WhiteSwan1 Jun 1, 2026
78a3ce2
[refactor] SID clip_loss: use built-in differentiable all_gather + di…
WhiteSwan1 Jun 1, 2026
f4851a3
[test] SID: add multi-rank test for SidRqkmeans.on_train_end DDP path
WhiteSwan1 Jun 1, 2026
26bcecd
[refactor] SidRqkmeans.on_train_end: drop redundant empty-buffer hand…
WhiteSwan1 Jun 1, 2026
e758006
[refactor] SID: move CLIP loss into tzrec/loss
WhiteSwan1 Jun 1, 2026
d2a0032
[refactor] clip_loss: MaskedCLIPLoss subclasses _Loss; fold in all_ga…
WhiteSwan1 Jun 1, 2026
0a9ca3b
[refactor] SidRqkmeans: use config_util.config_to_kwargs for faiss kw…
WhiteSwan1 Jun 1, 2026
a0daba8
[perf] SidRqkmeans: self-tuning reservoir buffer (bounded host memory)
WhiteSwan1 Jun 1, 2026
1ac30d3
[refactor] RQ-VAE kmeans_init: use FAISS K-Means + broadcast (drop to…
WhiteSwan1 Jun 1, 2026
f4fe6d5
[feat] RQ-KMeans: support non-uniform codebooks (e.g. [256, 512, 1024])
WhiteSwan1 Jun 1, 2026
f9fdea1
[perf] RQ-KMeans train_offline: feed FAISS torch tensors, GPU-accelerate
WhiteSwan1 Jun 1, 2026
b300adb
[refactor] SID: hoist shared input_dim / normalize_residuals into Bas…
WhiteSwan1 Jun 1, 2026
345f133
[chore] SidRqvae: use self._input_dim directly, drop redundant local …
WhiteSwan1 Jun 1, 2026
0fbff37
[fix] SID: wire on_train_end lifecycle hook into BaseModel + main.py
WhiteSwan1 Jun 1, 2026
0441ff2
[refactor] SidRqvae: merge _recon_loss and _masked_recon_loss
WhiteSwan1 Jun 2, 2026
aa8110d
[refactor] SidRqvae: unify loss keys across recon + CLIP paths
WhiteSwan1 Jun 2, 2026
5aae8ff
[test] SID: merge sid_rqkmeans_dist_test into sid_rqkmeans_test
WhiteSwan1 Jun 2, 2026
d67f923
[chore] RQ-KMeans: drop unused all_initialized property
WhiteSwan1 Jun 2, 2026
3856bbc
[bugfix] KMeans predict: drop data-dependent chunking so forward is F…
WhiteSwan1 Jun 2, 2026
39c88a4
[guard] VectorQuantize: reject use_sinkhorn + GUMBEL_SOFTMAX combo
WhiteSwan1 Jun 2, 2026
3667170
[refactor] ResidualQuantizer: hoist shared residual walk into the base
WhiteSwan1 Jun 3, 2026
4b6e9b0
[review] SID base: address github-actions review on PR #538
WhiteSwan1 Jun 3, 2026
f568de0
[review] SID base: second round of github-actions feedback on PR #538
WhiteSwan1 Jun 3, 2026
9dc276d
[review] SID: address tiankongdeguiji review on PR #538
WhiteSwan1 Jun 3, 2026
9801931
[review] SID: honest unique_sid_ratio framing + update_metric docstring
WhiteSwan1 Jun 3, 2026
c3fa1ed
Merge remote-tracking branch 'upstream/master' into feat/sid_abstract
WhiteSwan1 Jun 5, 2026
e97e742
[refactor] rename tzrec/modules/sid_generation -> tzrec/modules/sid; …
WhiteSwan1 Jun 5, 2026
b87989e
Merge remote-tracking branch 'upstream/master' into feat/sid_abstract
WhiteSwan1 Jun 5, 2026
995b23e
[feat] SID: add SidRqkmeans model (FAISS-trained residual K-Means)
WhiteSwan1 Jun 5, 2026
c7f3a09
[review] SID: drop forced tail-checkpoint after on_train_end
WhiteSwan1 Jun 8, 2026
61ec842
[review] SID: address code-review findings on PR #539
WhiteSwan1 Jun 8, 2026
753f3fe
[review] SID: default normalize_residuals to False
WhiteSwan1 Jun 8, 2026
52c7452
[review] SID: encapsulation, comment, and import cleanups
WhiteSwan1 Jun 8, 2026
fbd973f
[review] SID: move FAISS fit-sample sizing into the quantizer
WhiteSwan1 Jun 8, 2026
893a627
[review] SID: log rank0 FAISS-fit failure with traceback
WhiteSwan1 Jun 8, 2026
3734fc2
[review] SID: clarify the reservoir ceil-div comment
WhiteSwan1 Jun 9, 2026
795c676
[review] SID: fix FAISS gpu kwarg + close test gaps from PR review
WhiteSwan1 Jun 9, 2026
2bb5abc
[review] SID: default FAISS fit to CPU + DDP fit-failure test
WhiteSwan1 Jun 9, 2026
33acbe6
[review] SID: log the FAISS fit device (CPU/GPU)
WhiteSwan1 Jun 9, 2026
25a1e30
Merge remote-tracking branch 'upstream/master' into sid-2-rqkmeans
WhiteSwan1 Jun 9, 2026
23c552c
[chore] bump version to 1.2.18
WhiteSwan1 Jun 9, 2026
3261c2c
[review] SID: address 23c552c review (test timeout, N>=K assert, cap …
WhiteSwan1 Jun 9, 2026
e6e4d00
Merge upstream/master into sid-2-rqkmeans; bump version to 1.2.19
WhiteSwan1 Jun 9, 2026
39017ab
[review] checkpoint_util: force only overrides the dedupe
WhiteSwan1 Jun 9, 2026
5afbd5e
[review] checkpoint maybe_save: clarify final vs force docstrings
WhiteSwan1 Jun 9, 2026
415b8a3
[refactor] SidRqkmeans: single-process only; raise under DDP
WhiteSwan1 Jun 9, 2026
b27eb7b
[refactor] SidRqkmeans: move DDP guard to __init__ (fail fast)
WhiteSwan1 Jun 9, 2026
6f7ae1d
[simplify] SidRqkmeans: drop dead max(1,...) cap clamp; fold test _bu…
WhiteSwan1 Jun 9, 2026
5827d5b
[style] ruff-format the __init__ DDP guard (collapse to one line)
WhiteSwan1 Jun 9, 2026
4e2e878
[refactor] SidRqkmeans: CPU-only — raise on visible CUDA, drop device…
WhiteSwan1 Jun 9, 2026
4773e2a
[simplify] train_offline: assert host input; single-copy float32 own
WhiteSwan1 Jun 9, 2026
df83d07
[refactor] KMeansLayer.predict: use torch.cdist; drop _squared_euclid…
WhiteSwan1 Jun 10, 2026
d037db7
[refactor] SidRqkmeans: drop input_embedding from predictions
WhiteSwan1 Jun 10, 2026
88856f3
[simplify] trim SID docstrings (predict provenance; stale SidRqvae xref)
WhiteSwan1 Jun 10, 2026
2fa312b
[refactor] extract reservoir sampling into ReservoirSampler (kmeans.py)
WhiteSwan1 Jun 10, 2026
e296c8d
[refactor] ReservoirSampler: log capacity + dim on construction
WhiteSwan1 Jun 10, 2026
892a8d2
[fix] SID code-review: fail-fast cap, skip pre-fit eval, dedup MSE, d…
WhiteSwan1 Jun 10, 2026
b14304a
[simplify] SID: raise (not assert) for cap guard; name normalize_resi…
WhiteSwan1 Jun 10, 2026
eb39b5e
[style] SID: trim verbose comments
WhiteSwan1 Jun 10, 2026
8bf50aa
[refactor] SID: move init_metric/update_metric to BaseSidModel + Rela…
WhiteSwan1 Jun 10, 2026
e8a3609
[test] SID: add sid_integration_test (train -> fit -> checkpoint -> e…
WhiteSwan1 Jun 10, 2026
3dfbde0
[test] checkpoint: verify force re-save overwrites the same step
WhiteSwan1 Jun 10, 2026
d67ccd1
[review] split quantizer tests by module; clarify copy=True
WhiteSwan1 Jun 10, 2026
6a736c5
[refactor] drop CheckpointManager force param; SID uses no periodic c…
WhiteSwan1 Jun 10, 2026
5bc89d4
[refactor] typed FaissKmeansConfig proto; drop Struct + _coerce_proto…
WhiteSwan1 Jun 10, 2026
feeb4af
[refactor] add QuantizeLayer base; KMeansLayer -> KMeansQuantizeLayer
WhiteSwan1 Jun 10, 2026
a5d43b2
[refactor] unify reconstruction key to x_hat; drop _reconstruction hook
WhiteSwan1 Jun 10, 2026
c4c361a
[style] SID: trim redundant comments
WhiteSwan1 Jun 10, 2026
db7f2be
[refactor] QuantizeLayer: make lookup concrete in the base
WhiteSwan1 Jun 10, 2026
ed12cff
[refactor] QuantizeLayer: own n_clusters/n_features in the base
WhiteSwan1 Jun 10, 2026
d2697eb
[refactor] SID: extract QuantizeLayer ABC; rename kmeans -> kmeans_qu…
WhiteSwan1 Jun 10, 2026
097e9eb
[docs] checkpoint_util: tighten maybe_save `final` param docstring
WhiteSwan1 Jun 10, 2026
a9a889c
[fix] SID: review fixes + fail-fast validation; fix integration test …
WhiteSwan1 Jun 10, 2026
3b41df9
[review] SID: doc fixes, negative tests, stronger integration assertions
WhiteSwan1 Jun 10, 2026
5f5af01
[review] SID: drop _extract_feature width guard (embedding width is n…
WhiteSwan1 Jun 10, 2026
43e84ca
[fix] SID integration test: skip on CUDA, run on CPU CI
WhiteSwan1 Jun 11, 2026
949e438
Merge pr-539 (SID RQ-Kmeans + QuantizeLayer foundation) into feat/sid…
WhiteSwan1 Jun 11, 2026
85b9d40
[refactor] RQ-VAE: build on #539 QuantizeLayer ABC; retire old kmeans.py
WhiteSwan1 Jun 11, 2026
c838aec
[simplify] SID: drop VectorQuantize.forward delegator; trim redundant…
WhiteSwan1 Jun 11, 2026
a9b8d18
[simplify] SID: move faiss_residual_kmeans to its only user (RVQ module)
WhiteSwan1 Jun 11, 2026
9194ee0
[fix] RVQ.get_codebook_embeddings: detach the read-only accessor
WhiteSwan1 Jun 11, 2026
2284e86
[fix] SID code-review: CLIP empty-mask NaN, latent_weight validation,…
WhiteSwan1 Jun 11, 2026
12bd93e
[fix] RQ-VAE: make Gumbel-Softmax forward_mode actually functional (#1)
WhiteSwan1 Jun 11, 2026
15c5210
[style] SID: trim verbose comments and docstrings
WhiteSwan1 Jun 11, 2026
1f0aa4b
Merge master (official #539 squash + #542 doc) into feat/sid_abstract
WhiteSwan1 Jun 11, 2026
441cf19
[simplify] RQ-VAE: dedup gumbel predicate; move latent_weight check t…
WhiteSwan1 Jun 11, 2026
ba1d7f9
[fix] code-review: structural CLIP mask fill (finfo.min) + drop dead …
WhiteSwan1 Jun 15, 2026
c510212
[refactor] SID: share one-layer FAISS fit; tests for modified modules…
WhiteSwan1 Jun 15, 2026
6596b9e
[style] SID: ruff-format residual_quantizer_test (collapse one assert…
WhiteSwan1 Jun 15, 2026
9c2e872
[style] SID: simplify verbose comments/docstrings in RQ-VAE stack
WhiteSwan1 Jun 15, 2026
81c6bb7
[chore] SID: remove stray gumbel example config + mock-data generator
WhiteSwan1 Jun 15, 2026
8244af2
[refactor] SID: post-review cleanup + RQ-VAE robustness fixes
WhiteSwan1 Jun 15, 2026
cb39ca2
[simplify] SID: drop dead is_distributed param from _sinkhorn
WhiteSwan1 Jun 15, 2026
9cf7a62
[fix] SID: address PR review — logit_scale clamp, trim CLIP outputs, …
WhiteSwan1 Jun 15, 2026
4942670
[fix] SID: address 2nd PR review round — epsilon guard, doc fixes, na…
WhiteSwan1 Jun 16, 2026
2223634
[chore] bump version 1.2.19 -> 1.2.20
WhiteSwan1 Jun 16, 2026
fce258c
[refactor] SidRqvae._predict_rqvae: gate on _is_inference directly
WhiteSwan1 Jun 18, 2026
a752719
[test] SID: fix test-class colocation (foo.py -> foo_test.py)
WhiteSwan1 Jun 18, 2026
b3f8105
[fix] SID: address PR #545 review rounds 2-3
WhiteSwan1 Jun 18, 2026
a8c4592
[fix] SID: review D4/D5 (quantize API) + fix rqvae integration test
WhiteSwan1 Jun 18, 2026
dbab20b
[refactor] SID: config-driven losses via LossConfig sid_loss (review …
WhiteSwan1 Jun 18, 2026
8f0d882
[refactor] SID: /simplify cleanups on the config-driven loss refactor
WhiteSwan1 Jun 22, 2026
5a7370f
[refactor] SID: generalize CLIP loss -> masked InfoNCE; model owns st…
WhiteSwan1 Jun 22, 2026
2002daf
[refactor] SID: extract _masked_mean; inline recon distance into _sid…
WhiteSwan1 Jun 22, 2026
0c2f6e1
[refactor] SID: bind recon loss at init + merge ReconL2/L1/Cosine -> …
WhiteSwan1 Jun 22, 2026
105abe3
[feat] SID: consume framework EmbeddingGroup/build_input (drop _extra…
WhiteSwan1 Jun 22, 2026
e11faab
[feat] SID CLIP: fail-fast dim guard + mock config; trim comments
WhiteSwan1 Jun 22, 2026
e8c72fd
[refactor] SID: review cleanup — merge RQ-VAE pass, fix eval metric +…
WhiteSwan1 Jun 22, 2026
02b0b28
[chore] SID: drop comments that just restate their error messages/doc…
WhiteSwan1 Jun 22, 2026
2276c9d
[refactor] SID: land review #2 (framework MLP) + #10 (cdist)
WhiteSwan1 Jun 22, 2026
81394dc
Merge remote-tracking branch 'upstream/master' into feat/sid_abstract
WhiteSwan1 Jun 22, 2026
b33401b
[chore] bump version to 1.2.21
WhiteSwan1 Jun 22, 2026
5407134
[refactor] SID: simplify per review + give BaseSidModel logic a test …
WhiteSwan1 Jun 22, 2026
dd57651
[bugfix] SID: restore codebook gradient — RVQ quantize returns raw ve…
WhiteSwan1 Jun 23, 2026
6731134
Merge remote-tracking branch 'upstream/master' into feat/sid_abstract
WhiteSwan1 Jun 23, 2026
b203520
[chore] SID: drop redundant/misplaced quantizer tests + fix stale STE…
WhiteSwan1 Jun 23, 2026
4239390
[refactor] SID: extract CLIP wiring into SidRqvae._init_clip
WhiteSwan1 Jun 23, 2026
0bc8d18
Merge remote-tracking branch 'upstream/master' into feat/sid_abstract
WhiteSwan1 Jun 23, 2026
c4cbd64
[chore] bump version to 1.2.22
WhiteSwan1 Jun 23, 2026
1d0fdcb
Merge remote-tracking branch 'upstream/master' into feat/sid_abstract
WhiteSwan1 Jun 24, 2026
a25be45
[chore] SID: drop comments that describe an absence / refactor history
WhiteSwan1 Jun 24, 2026
46b5ac2
[refactor] SID: default clip group names to None up front in _init_clip
WhiteSwan1 Jun 24, 2026
08c1bb2
[refactor] SID: move CLIP temperatures into MaskedInfoNCELoss
WhiteSwan1 Jun 24, 2026
714d3c0
[refactor] SID: restore the clamp+exp helper in MaskedInfoNCELoss
WhiteSwan1 Jun 24, 2026
210bfab
[refactor] SID: de-CLIP rename + uniform loss-modularization
WhiteSwan1 Jun 25, 2026
214c755
[refactor] SID: strip comments that restate logic in the contrastive …
WhiteSwan1 Jun 25, 2026
6a62c6a
[refactor] SID: strip restate comments (PR-wide comment refinement)
WhiteSwan1 Jun 25, 2026
527da18
[refactor] SID: streamline narration/rationale comments
WhiteSwan1 Jun 25, 2026
6eca7c7
[refactor] SID: ruff-format residual_vector_quantizer
WhiteSwan1 Jun 25, 2026
d8e24d9
[refactor] SID: SidReconLoss masked-mean; drop feature_group field (a…
WhiteSwan1 Jun 25, 2026
60eac78
[test] SID: strengthen reservoir phase-2 assertion
WhiteSwan1 Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions tzrec/loss/sid_commitment_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2026, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SidCommitmentLoss: VQ-VAE commitment loss for residual quantizers."""

from typing import Sequence

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class SidCommitmentLoss(_Loss):
"""Commitment loss between the encoder output and the quantized vectors.

Operates on a residual quantizer's per-layer cumulative quantized vectors
(the ``latents`` of :class:`~tzrec.modules.sid.types.ResidualQuantizerOutput`).
Both VQ-VAE directions are summed and averaged over the residual layers:

- ``loss1`` = encoder-toward-quant (gradient flows into the encoder), w1
- ``loss2`` = quant-toward-encoder (gradient flows into the codebook), w2

Args:
latent_weight (Sequence[float]): commitment weights ``[w1, w2]``.
Default: ``(1.0, 0.5)``.
commitment_type (str): distance, ``"l2"``, ``"l1"`` or ``"cos"``.
Default: ``"l2"``.
"""

def __init__(
self,
latent_weight: Sequence[float] = (1.0, 0.5),
commitment_type: str = "l2",
) -> None:
super().__init__()
if len(latent_weight) != 2:
raise ValueError(
f"latent_weight must have exactly 2 values [w1, w2], got "
f"{list(latent_weight)}"
)
assert commitment_type in ("l2", "l1", "cos"), (
f"commitment_type must be 'l2', 'l1' or 'cos', got {commitment_type!r}"
)
self.commitment_w1, self.commitment_w2 = latent_weight
self.commitment_type = commitment_type

def forward(self, encoder_out: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
"""Compute the commitment loss.

Args:
encoder_out (Tensor): encoder output (the quantizer input), shape
(B, D).
latents (Tensor): per-layer cumulative quantized vectors, shape
(B, n_layers, D).

Returns:
Tensor: scalar commitment loss (averaged over layers).
"""
x = encoder_out.unsqueeze(1) # (B, 1, D) -> broadcasts over layers
if self.commitment_type == "cos":
loss1 = (1 - F.cosine_similarity(x, latents.detach(), dim=-1)).mean()
loss2 = (1 - F.cosine_similarity(x.detach(), latents, dim=-1)).mean()
elif self.commitment_type == "l1":
loss1 = (x - latents.detach()).abs().mean()
loss2 = (x.detach() - latents).abs().mean()
else: # "l2"
loss1 = (x - latents.detach()).pow(2.0).mean()
loss2 = (x.detach() - latents).pow(2.0).mean()
return self.commitment_w1 * loss1 + self.commitment_w2 * loss2
65 changes: 65 additions & 0 deletions tzrec/loss/sid_commitment_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2026, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from tzrec.loss.sid_commitment_loss import SidCommitmentLoss


class SidCommitmentLossTest(unittest.TestCase):
"""Tests for the standalone SidCommitmentLoss module."""

@parameterized.expand([("l2",), ("l1",), ("cos",)])
def test_branch_runs_and_backprops(self, commitment_type) -> None:
"""Each commitment_type runs end-to-end; grad reaches both operands."""
torch.manual_seed(0)
loss_fn = SidCommitmentLoss(
latent_weight=(1.0, 0.5), commitment_type=commitment_type
)
B, L, D = 4, 3, 8
encoder_out = torch.randn(B, D, requires_grad=True)
latents = torch.randn(B, L, D, requires_grad=True)
out = loss_fn(encoder_out, latents)
self.assertEqual(out.shape, ())
self.assertTrue(torch.isfinite(out))
out.backward()
# loss1 (encoder-toward-quant) feeds encoder_out; loss2 feeds latents.
self.assertIsNotNone(encoder_out.grad)
self.assertIsNotNone(latents.grad)
self.assertTrue(torch.isfinite(encoder_out.grad).all())

def test_latent_weight_wrong_length_raises(self) -> None:
"""latent_weight must be exactly [w1, w2]."""
for bad in ([1.0], [1.0, 0.5, 0.25]):
with self.assertRaisesRegex(ValueError, "latent_weight"):
SidCommitmentLoss(latent_weight=bad)

def test_invalid_commitment_type_raises(self) -> None:
"""An unknown commitment_type is rejected."""
with self.assertRaisesRegex(AssertionError, "commitment_type"):
SidCommitmentLoss(commitment_type="bogus")

def test_weights_scale_the_two_directions(self) -> None:
"""w1/w2 weight the encoder-toward-quant / quant-toward-encoder terms."""
torch.manual_seed(0)
encoder_out = torch.randn(4, 8)
latents = torch.randn(4, 3, 8)
base = SidCommitmentLoss(latent_weight=(1.0, 0.5), commitment_type="l2")
zero = SidCommitmentLoss(latent_weight=(0.0, 0.0), commitment_type="l2")
self.assertGreater(base(encoder_out, latents).item(), 0.0)
self.assertEqual(zero(encoder_out, latents).item(), 0.0)


if __name__ == "__main__":
unittest.main()
194 changes: 194 additions & 0 deletions tzrec/loss/sid_contrastive_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) 2026, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Masked InfoNCE contrastive loss with distributed all-gather support."""

import math
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss

# CLIP temperature init (reference CLIP: log(1 / 0.07)) and the cap applied
# before ``exp`` (reference CLIP clamps to ln(100)): an unbounded temperature
# would overflow to +Inf -> NaN grad -> corrupt param.
_LOGIT_SCALE_INIT = math.log(1 / 0.07)
_LOGIT_SCALE_MAX = math.log(100)


class SidContrastiveLoss(_Loss):
"""Masked InfoNCE pair-contrastive loss for mixed (paired + non-paired) batches.

Modality-agnostic: aligns two reconstructed "views" (``embed_a`` / ``embed_b``)
against each other and against their originals (``embed_a_ori`` /
``embed_b_ori``) with three symmetric InfoNCE terms (self/ori/cl). In a mixed
batch, non-pair rows (``pair_mask=False``) must not contribute and must not
serve as negatives; row/column masks achieve this without data-dependent
branching (``torch.compile``-friendly).

``forward`` takes the four ``(B, dim)`` view embeddings plus the ``(B,)`` pair
mask and returns the scalar mean of the three contrastive terms. The three
temperatures (self/ori/cl) are learnable parameters owned by this module;
``forward`` clamps (to <= ln(100)) and ``exp``s them.
"""

def __init__(self) -> None:
super().__init__()
self.labels: Optional[torch.Tensor] = None
self.last_local_batch_size: Optional[int] = None
self._rank = dist.get_rank() if dist.is_initialized() else 0
# Learnable contrastive temperatures, one per group (self / ori / cl);
# registered here so the loss module is self-contained.
self.logit_scale_self = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT)
self.logit_scale_cl = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT)
self.logit_scale_ori = nn.Parameter(torch.ones([]) * _LOGIT_SCALE_INIT)

@staticmethod
def _scaled(logit_scale: torch.Tensor) -> torch.Tensor:
# Clamp before exp so a large temperature can't overflow to +Inf -> NaN.
return logit_scale.clamp(max=_LOGIT_SCALE_MAX).exp()

@staticmethod
def _all_gather_with_grad(tensors: List[torch.Tensor]) -> List[torch.Tensor]:
"""All-gather tensors across workers with gradient support.

Single-process: returns the inputs unchanged. Multi-process: uses the
built-in differentiable ``torch.distributed.nn.functional.all_gather``,
so no custom ``autograd.Function`` is needed.

Args:
tensors (List[Tensor]): list of tensors to gather.

Returns:
List[Tensor]: gathered tensors, each (world_size * B, ...).
"""
if not dist.is_initialized() or dist.get_world_size() == 1:
return tensors
gathered: List[torch.Tensor] = []
for tensor in tensors:
tensor_all = dist_nn.all_gather(tensor) # differentiable, per rank
gathered.append(torch.cat(tensor_all, dim=0))
return gathered

@staticmethod
def _gather_bool_mask(mask: torch.Tensor) -> torch.Tensor:
"""All-gather bool mask across distributed workers."""
if not dist.is_initialized() or dist.get_world_size() == 1:
return mask
mask_list = [torch.zeros_like(mask) for _ in range(dist.get_world_size())]
dist.all_gather(mask_list, mask)
return torch.cat(mask_list, dim=0)

def _masked_cross_entropy(
self,
logits_a: torch.Tensor,
logits_b: torch.Tensor,
safe_labels: torch.Tensor,
pair_mask_f: torch.Tensor,
n_valid: torch.Tensor,
) -> torch.Tensor:
"""Masked cross-entropy on column-masked logits, row-masked average.

Args:
logits_a: (B, B_global) column-masked logits (view-a branch).
logits_b: (B, B_global) column-masked logits (view-b branch).
safe_labels: (B,) labels with non-pair rows fallback to a safe col.
pair_mask_f: (B,) float pair mask (1.0 = pair row).
n_valid: scalar pair-row count, clamped to >= 1.
"""
ce_a = F.cross_entropy(logits_a, safe_labels, reduction="none")
ce_b = F.cross_entropy(logits_b, safe_labels, reduction="none")
# Backstop against a non-finite upstream logit (e.g. overflowed scale).
ce_a = torch.nan_to_num(ce_a, nan=0.0)
ce_b = torch.nan_to_num(ce_b, nan=0.0)

return ((ce_a + ce_b) * pair_mask_f).sum() / (2 * n_valid)

def forward(
self,
embed_a: torch.Tensor,
embed_b: torch.Tensor,
embed_a_ori: torch.Tensor,
embed_b_ori: torch.Tensor,
pair_mask: torch.Tensor,
) -> torch.Tensor:
"""Compute the masked pair-contrastive loss.

Args:
embed_a: (B, dim) reconstructed (decoder) output of view a.
embed_b: (B, dim) reconstructed (decoder) output of view b.
embed_a_ori: (B, dim) original embedding of view a.
embed_b_ori: (B, dim) original embedding of view b.
pair_mask: (B,) bool, True = contrastive-pair sample.

Returns:
Tensor: scalar mean of the three contrastive terms (self/ori/cl).
"""
logit_scale_self = self._scaled(self.logit_scale_self)
logit_scale_ori = self._scaled(self.logit_scale_ori)
logit_scale_cl = self._scaled(self.logit_scale_cl)

local_batch_size = embed_a.size(0)

# Labels carry the cross-rank offset, so refresh them on batch-size change.
if local_batch_size != self.last_local_batch_size:
self.labels = local_batch_size * self._rank + torch.arange(
local_batch_size, device=embed_a.device
)
self.last_local_batch_size = local_batch_size

embed_a = F.normalize(embed_a, dim=-1, p=2)
embed_b = F.normalize(embed_b, dim=-1, p=2)

# One batched all-gather for all four operands (gradient-preserving).
embed_a_all, embed_b_all, embed_a_all_ori, embed_b_all_ori = (
self._all_gather_with_grad([embed_a, embed_b, embed_a_ori, embed_b_ori])
)

pair_mask_all = self._gather_bool_mask(pair_mask)
col_mask = (~pair_mask_all).unsqueeze(0) # (1, B_global)

labels = self.labels
fallback = pair_mask.long().argmax() # first pair sample index
safe_labels = torch.where(pair_mask, labels, fallback.expand_as(labels))
pair_mask_f = pair_mask.float()
n_valid = pair_mask_f.sum().clamp(min=1)

# Three symmetric contrastive groups, each (scale, a-target, b-target):
# self: recon-a vs recon-b (vs the other recon view)
# ori: recon vs the counterpart original
# cl: recon vs its own-view original
groups = (
(logit_scale_self, embed_b_all, embed_a_all),
(logit_scale_ori, embed_b_all_ori, embed_a_all_ori),
(logit_scale_cl, embed_a_all_ori, embed_b_all_ori),
)
loss = embed_a.new_zeros(())
for scale, a_target, b_target in groups:
logits_a = scale * embed_a @ a_target.t()
logits_b = scale * embed_b @ b_target.t()
# Fill masked columns with the LOGITS dtype's most negative finite
# value: below any real logit (masks like -inf) but finite, so an
# all-non-pair row yields a finite CE/grad instead of 0*NaN. Derive
# it from the logits dtype, not the embeddings': under autocast the
# matmul casts to bf16/fp16 and finfo(embed.dtype=fp32).min would
# overflow masked_fill on the lower-precision logits.
neg_fill = torch.finfo(logits_a.dtype).min
logits_a = logits_a.masked_fill(col_mask, neg_fill)
logits_b = logits_b.masked_fill(col_mask, neg_fill)
loss = loss + self._masked_cross_entropy(
logits_a, logits_b, safe_labels, pair_mask_f, n_valid
)
return loss / 3
Loading
Loading