[feat] SID: add SidRqvae (RQ-VAE) model — STE/Gumbel-Softmax + CLIP#545
Conversation
Bring the SID-generation stack (from the remove_ema_2 working branch) onto a clean upstream base as the starting point for the base-class abstraction refactor. Net-new files only: - models: sid_rqvae.py, sid_rqkmeans.py (+ tests), _sid_helpers.py - modules/sid_generation: rqvae, residual_quantized, residual_kmeans, kmeans, vector_quantize, clip_loss, types - protos: sid_model.proto + SidRqvae/SidRqkmeans wired into model.proto 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (sid_model.proto:54), replace the stringly-typed list fields with the tzrec-conventional repeated numeric types: hidden_dims / codebook : string -> repeated uint32 latent_weight : string -> repeated float This moves validation to proto-load time, restores text_format type checking, and removes the ad-hoc tzrec/models/_sid_helpers.py shim (parse_int_list / parse_float_list). It also fixes the always-truthy `if cfg.latent_weight:` guard noted in review — an unset repeated field is an empty (falsy) list, so the signature-default branch is now real. Wrappers/tests updated to pass lists; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…enames Per review (residual_kmeans.py:29, residual_quantized.py:30), abstract the two residual-quantization backends behind a shared base and align names with the tzrec module convention: - new modules/sid_generation/residual_quantizer.py: ResidualQuantizer abstract base — owns embed_dim/n_layers/n_embed_list/normalize_residuals, the backend-agnostic decode_codes (via a _lookup_code primitive) and output_dim. Subclasses build self.layers and implement forward/get_codes/ get_codebook_embeddings/_lookup_code. - residual_quantized.py -> residual_vector_quantizer.py ResidualQuantized -> ResidualVectorQuantizer (VQ, gradient-trained) - residual_kmeans.py -> residual_kmeans_quantizer.py ResidualKMeans -> ResidualKMeansQuantizer (offline FAISS) - types.ResidualQuantizedOutput -> ResidualQuantizerOutput - rqvae.py / __init__.py / tests / docstrings updated to the new names. Behavior unchanged; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (sid_rqvae.py:211 — "add a parent class for rqkmeans and rqvae for shared init_metric and update_metric"), introduce tzrec/models/sid_model.py::BaseSidModel(BaseModel) and have both SidRqvae and SidRqkmeans inherit it. The base owns the structure the two models duplicated: - __init__ scaffolding: embedding_feature_name + codebook -> n_embed_list / n_layers. - _extract_feature(batch, feature_name=None) (replaces the per-model _extract_feature / _extract_embedding copies). - init_loss (SID losses are internal; no module to register). - init_metric registering the shared eval metrics (mse, unique_sid_ratio); subclasses call super().init_metric() then add extras (RQ-VAE: train-path mse; RQ-KMeans: rel_loss). - _update_unique_sid_ratio(codes) shared by both update_metric paths. - a default no-op update_train_metric (RQ-VAE overrides it). Behavior unchanged; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (sid_rqkmeans.py:1). Applies to the 14 net-new SID source/test files added in this branch. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (residual_kmeans.py:293 — "Why two layers of abstraction? I think RQKMeans is not needed"). RQKMeans was a thin nn.Module that just held a ResidualKMeansQuantizer and forwarded every call. SidRqkmeans now owns a ResidualKMeansQuantizer directly (self._quantizer); its forward returns (codes, quantized) so predict unpacks the tuple. Removed the RQKMeans class + export; updated tests/docstrings. Re review (sid_rqkmeans.py:88 — "use config_to_kwargs"): config_to_kwargs is currently broken framework-wide under protobuf 5.x (it passes the removed `including_default_value_fields` kwarg), so it raises on every config. Kept a version-safe MessageToDict with a NOTE pointing at the helper for when it's fixed. 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (rqvae.py:26 — "I think RQVAE should be refactored into sid_rqvae.py"). The encoder/decoder MLPs, the ResidualVectorQuantizer, and the CLIP head now live directly on SidRqvae; the forward_rqvae / forward_mixed / loss helpers become private model methods (_forward_rqvae / _forward_mixed / _recon_loss / _masked_recon_loss). Deleted modules/sid_generation/rqvae.py and its RQVAE export. Drops the dead bits the wrapper carried: the never-set _is_inference dispatch and the unreachable commitment_loss=None default (the proto always supplies "l2"). Behavior unchanged; 18/18 unit tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- vector_quantize.py:204 — use torch.std_mean instead of torch.var_mean +
rsqrt for the pre-Sinkhorn z-score (cleaner, equivalent).
- sid_rq{vae,kmeans}_test.py — drop the feature_groups that referenced
"item_emb" while features=[] (SID models read the dense feature directly
and never consume feature_groups); config is now internally consistent.
- sid_rqkmeans.py — document why predict buffers to host (.cpu()): the
full corpus is accumulated before one FAISS pass, so GPU residency would
OOM and faiss-cpu can't take CUDA tensors.
18/18 unit tests pass.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (__init__.py:12 — "Add unit tests for all modules and functions"). Adds colocated *_test.py covering the modules that previously only had indirect, model-level coverage: - kmeans_test.py: recon_diagnostics, _squared_euclidean_distance (+ chunked-equivalence), _kmeans / _residual_kmeans shapes, and the KMeansLayer load/predict/round-trip + mid-fit-checkpoint guard. - vector_quantize_test.py: VectorQuantize STE/Gumbel x l2/cosine x sinkhorn forward, STE gradient-to-input, eval plain-lookup. - residual_quantizer_test.py: normalize_n_embed, the abstract base's shared output_dim/decode_codes + NotImplementedError primitives, and both subclasses (ResidualVectorQuantizer / ResidualKMeansQuantizer) incl. the non-uniform-codebook reject and an offline FAISS fit. - clip_loss_test.py: single-process MaskedCLIPLoss (all-clip finite, all-recon zero, backward-to-embeddings) and _all_gather_with_grad single-process identity. 51/51 SID unit tests pass (18 model + 33 module). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…st test Per review (clip_loss.py:50 — "Why not directly use torch.distributed.nn.functional.all_gather?"). Replace the hand-rolled GatherLayer autograd.Function with torch.distributed.nn.functional .all_gather inside _all_gather_with_grad; its backward already sum-reduces the per-rank grads and returns this rank's slice, so the custom Function (and its GatherLayer export) are gone. Adds clip_loss_dist_test.py: a 2-rank multi-process test (NCCL on GPU when >=2 devices, else gloo/CPU) asserting all_gather forward values, the world_size-summed backward, and a MaskedCLIPLoss forward/backward across ranks. Validated on 2x GPU (NCCL). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (sid_rqkmeans_test.py:133 — the single-rank test never enters the DDP branch; the gather_object/broadcast/_is_initialized path was untested). Adds a 2-rank multi-process test (NCCL on GPU when >=2 devices, else gloo) that fills each rank's offline buffer, runs on_train_end, and asserts: every rank ends initialized with non-zero, cross-rank-identical (broadcast) centroids, and eval predict emits valid in-range codes. Empirically refutes the "gather_object is incompatible with NCCL" concern from review for the pinned torch: the path completes on NCCL (validated on 2x GPU). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…shake Per review (tiankongdeguiji, sid_rqkmeans.py:250 — "In synchronized training, some ranks shouldn't have empty data—this logic is redundant"). The dataset layer already guarantees it: file-based datasets enforce `num_files >= world_size` (tzrec/datasets/dataset.py raises otherwise), so in synchronized DDP training every rank receives at least one shard and reaches the gather with a non-empty buffer. The cross-rank all_reduce(MAX) empty-flag handshake was therefore dead insurance. Removed it: the DDP branch now goes straight to gather_object/fit/ broadcast. The single-process branch keeps a plain local empty-buffer no-op guard (not a collective) so on_train_end without a training pass still degrades gracefully. Verified: single-process unit tests (incl. empty-buffer no-op) and the 2-rank NCCL on_train_end test pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (clip_loss.py:75 — "refactor into tzrec/loss"). MaskedCLIPLoss is a generic contrastive loss, not a SID quantization primitive, and has no sid_generation dependencies — so it belongs with the other loss modules (focal_loss, jrc_loss, ...). - tzrec/modules/sid_generation/clip_loss.py -> tzrec/loss/clip_loss.py (+ colocated clip_loss_test.py / clip_loss_dist_test.py). - SidRqvae imports it from tzrec.loss.clip_loss; dropped the MaskedCLIPLoss re-export from sid_generation/__init__.py. Behavior unchanged; single-process + 2-rank NCCL clip tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ther Per review follow-up: - MaskedCLIPLoss now subclasses torch.nn.modules.loss._Loss (matching the tzrec/loss convention, e.g. BinaryFocalLoss) instead of bare nn.Module. - The module-level _all_gather_with_grad helper had MaskedCLIPLoss as its only (production) caller, so it becomes a private @staticmethod MaskedCLIPLoss._all_gather_with_grad alongside _gather_bool_mask. Tests updated to the static-method form; single-process + 2-rank NCCL clip tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…args Per review (tiankongdeguiji, sid_rqkmeans.py:88 — "use config_to_kwargs"). Replace the bespoke MessageToDict call with the project-standard helper that ~35 other models already use (rank_model, match_model, dlrm, ...). config_to_kwargs returns Struct numbers as floats, so _coerce_proto_numbers is kept to restore the ints faiss.Kmeans expects (niter/seed/nredo). Note: config_to_kwargs passes MessageToDict(..., including_default_value_ fields=...), which protobuf 5.x renamed/removed — so it (like every other config_to_kwargs caller in tzrec) requires protobuf 4.x. Validated on the supported env (protobuf 4.25.9): the SidRqkmeans suite passes. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Replaces the unbounded per-step offline buffer (every embedding .cpu()'d and kept) with Vitter Algorithm-R reservoir sampling into a fixed-size host buffer, fixing the rank-0 OOM risk on large corpora. Self-tuning cap: FAISS K-Means only ever consumes K*max_points_per_centroid points (it subsamples internally), so that is the target. New proto field train_sample_size (0 = auto) sets the global target; the per-rank cap is target/world_size so the gathered set on rank0 is ~target and FAISS does no further subsampling. With the default max_points_per_centroid=256 and K=256 that's ~65K rows/layer instead of the whole corpus. on_train_end now consumes the reservoir sample directly (gather -> fit -> broadcast) and releases it. Tests updated to the reservoir state; added test_reservoir_caps_memory. Algorithm verified uniform + capped in isolation; model paths validated on the remote (protobuf 4.x) below. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…rch Lloyd) Per review (kmeans.py:134 "Why not use Faiss KMeans instead?" and residual_quantized.py:188 "averaging KMeans centroids across ranks is meaningless"). - New kmeans.faiss_residual_kmeans(): FAISS residual K-Means warm-start, the same backend the offline RQ-KMeans fit uses. Replaces the torch-native _kmeans / _kmeans_plus_plus / _residual_kmeans (deleted — they were the O(K^2 N), non-deterministic, single-batch Lloyd path only init used). - ResidualVectorQuantizer.init_embed_ now fits on rank 0 only and dist.broadcast(src=0)s the codebook, so every rank starts identical — instead of all_reduce-averaging permutation-misaligned per-rank centroids. Tests: swapped the torch-kmeans unit tests for faiss_residual_kmeans + a kmeans_init=True seeding test; added a 2-rank NCCL test asserting the broadcast yields bit-identical codebooks across ranks (validated on 2x GPU). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per request. FAISS itself supports any K per instance; only our offline wrapper was restricted because train_offline reused a single faiss.Kmeans across layers (hence the uniformity assert). Now: - ResidualKMeansQuantizer.train_offline builds a fresh faiss.Kmeans per layer with that layer's K (index construction is a cheap O(K*D) alloc next to train(), so effectively free); uniformity assert removed. - SidRqkmeans reservoir cap now derives from max(n_embed_list) so the largest layer is fed K*max_points_per_centroid points (non-uniform would otherwise under-sample the big layer). - proto + docstrings updated; RQ-KMeans now matches RQ-VAE, which already supported per-layer K. Tests: swapped the "rejected" assert for non-uniform support + a non-uniform train_offline fit, and an end-to-end SidRqkmeans [8,4,16] test asserting the cap uses max(K) and per-layer codes stay in range. Module tests pass locally; model paths validated on the remote (protobuf 4.x) below. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Per review (residual_kmeans.py:229 "faiss can directly accept a torch tensor, do not need to convert numpy"). - import faiss.contrib.torch_utils and pass torch tensors to Kmeans.train / index.search directly — no numpy round-trips; centroids/codes flow as torch tensors. - Auto-select FAISS GPU compute when a faiss-gpu build is present (gpu=current_device, overridable via faiss_kmeans_kwargs['gpu']); falls back to CPU on faiss-cpu builds. The residual matrix stays a host tensor — FAISS streams only its subsampled (k*max_ppc) working set to the GPU, so we never hold (N,D) in VRAM (no A10 OOM risk). Same code path both ways. Measured ~80x faster training on GPU (0.2s vs 16s, N=262k/k=1024/niter=10). CPU path validated locally; GPU path validated on the H20 (faiss-gpu) below. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…eSidModel Both SID proto messages carry input_dim and normalize_residuals, and both models re-read them. Move the parsing into BaseSidModel.__init__ (alongside the already-shared embedding_feature_name and codebook), exposing self._input_dim / self._normalize_residuals. SidRqvae and SidRqkmeans now use the base attributes instead of re-reading cfg. No behavior change; sid_rqvae tests pass locally, full suite validated on the remote below. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…alias Follow-up to hoisting input_dim into BaseSidModel — there's no need for a local `input_dim = self._input_dim` alias; reference the base attribute directly in the encoder/decoder dims and init log. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The SID port brought the models/modules/protos but not the lifecycle wiring, so SidRqkmeans.on_train_end (which runs the FAISS fit) was never invoked by a real train_eval run — only the unit tests called it directly. A real run would finish with an unfit (zero) codebook and predict would emit all-zero codes. - BaseModel.on_train_end: add the no-op base hook (so every model has it). - main.py: call _model.on_train_end() after the train loop, and force the tail-save to fire afterwards (last_ckpt_step guard) so the post-hook state — e.g. the freshly fit FAISS codebook — is always persisted, even when the last in-loop checkpoint coincided with the final step. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Both computed the same per-loss_type reconstruction loss (mse/l1/cosine); the only difference was the reduction. Fold into a single _recon_loss with an optional per-row mask: no mask -> mean over all rows (== the old reduction="mean"); mask -> mean over the masked-in rows (the mixed recon+CLIP path). _forward_mixed now calls _recon_loss(..., recon_mask). Behavior unchanged; 12/12 sid_rqvae tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The standard and mixed-CLIP paths emitted the same two losses under different dict keys: reconstruction_loss/recon_loss (the reconstruction loss) and quantization_loss/commitment_loss (the RVQ commitment loss). That also made the logged metric names differ by mode. Standardize on reconstruction_loss + quantization_loss everywhere (matches the quantizer's ResidualQuantizerOutput.quantization_loss field). loss() now always emits reconstruction_loss + quantization_loss, plus clip_loss when use_clip. Left the commitment_loss= constructor arg (the loss-type knob) and the _recon_loss method name untouched. Tests updated; 12/12 pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The two files tested the same model at different scales — single-process (proto parse, reservoir, FAISS fit, checkpoints, non-uniform) vs a 2-rank multi-process on_train_end DDP path. They duplicated _make_batch and the model-config builder. Merge into one file (matching the tzrec convention of co-locating dist tests, e.g. checkpoint_util_test): - Shared module-level helpers: _make_batch(..., device="cpu") and _build_model(...); the unit class's _create_model now wraps _build_model + init_parameters, and the spawned DDP worker reuses _build_model. - SidRqkmeansOfflineTest (single-process) + SidRqkmeansDistTest (2-rank, NCCL on GPU else gloo) now live together; deleted sid_rqkmeans_dist_test.py. Logic unchanged from the previously remote-validated tests (8 unit + 1 DDP). Structure verified locally (imports, both classes, picklable worker); full run needs a protobuf-4.x env (config_to_kwargs), as before. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The all_initialized property had no production caller — forward() checks each layer's is_initialized individually, and SidRqkmeans never used it. It was referenced only by residual_quantizer_test. Removed it; the tests now check all(layer.is_initialized for layer in rkq.layers) inline. 15/15 residual_quantizer tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…X-traceable torchrec's inference pipeline symbolically traces the model. The `if N <= chunk_size:` branch in `_squared_euclidean_distance` keyed off the traced batch dim, raising `torch.fx.proxy.TraceError` during predict export. The chunked path was only reachable per-batch from `KMeansLayer.predict` (small N — the offline fit uses FAISS, not this function), so the chunking was unnecessary as well as FX-breaking. Simplify to a branch-free (x_sq + y_sq - 2 x@y.T).clamp(min=0). Drop the now-dead chunk_size param and its test; add an FX symbolic-trace regression test on ResidualKMeansQuantizer.forward. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sinkhorn and Gumbel-Softmax pick the code by two different rules: with Sinkhorn on, `ids = Q.argmax` (balanced optimal-transport assignment), while the Gumbel branch builds `emb` from argmax(-distances + noise) (nearest code). The two indices generally diverge, so the saved semantic ID would not match the codebook vector actually reconstructed and trained. STE has no such issue since it looks up embedding(ids) directly. Add a constructor assert forbidding the inconsistent combo (STE+Sinkhorn or Gumbel-without-Sinkhorn remain valid), retarget the gumbel test param to use_sinkhorn=False, and add a test that the rejected combo raises. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Both backends' forward loops did the same thing — per layer: normalize the residual, assign codes, look up the quantized vector, subtract it, accumulate — differing only in (a) how a layer produces (codes, quant) and (b) VQ's per-layer commitment loss. Consolidate the shared structure: - New abstract primitive `_quantize_layer(layer_idx, residual, temperature)` -> (codes, quant), the encode-direction mirror of `_lookup_code`. K-Means runs predict()+centroids (with the uninitialized-layer zero guard moved inside it); VQ runs the VectorQuantize layer and returns the raw, grad-carrying codebook vector. - New concrete `_residual_pass()` in the base drives the walk and returns (cluster_ids, aggregated, cumulative). The residual subtraction uses `quant.detach()` — required for VQ's gradient semantics, a no-op for K-Means (buffer lookup) — so one line serves both. - `get_codes` is now concrete in the base (mirrors decode_codes), so both subclasses drop their copies. VQ no longer routes get_codes through a full training forward (no wasted loss/STE). - Each forward shrinks to: K-Means returns the final sum; VQ adds init_embed_, maps the commitment loss over `cumulative`, and applies STE/rotation. The per-layer commitment loss stays in VQ.forward (mapped over the returned cumulative quants) rather than leaking a loss hook into the shared walk. Behavior is unchanged: the two forwards remain numerically identical and the predict path stays FX-traceable (the loop is now in `_residual_pass`). Tests: add forward-vs-get_codes consistency for both backends; existing VQ backward / FAISS-init / dist-broadcast / K-Means FX-trace / decode tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- decode_codes: seed the accumulator from the first lookup so device AND dtype follow the codebook, instead of pinning to fp32 (silently upcasting each layer's add under mixed precision). n_layers >= 1 is guaranteed. - BaseSidModel._update_unique_sid_ratio: guard B == 0 (empty final shard under DDP/TorchRec) to avoid ZeroDivisionError. - residual_quantizer_test: add a fake one-primitive subclass exercising the concrete residual walk the base owns — get_codes shape + aggregate == Σ quantized_i, the detach invariant (codebook grad flows, input gets none), the normalize_residuals branch, and decode_codes sum + codebook dtype. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ct_feature)
Address review (round 5): use the standard build_input / EmbeddingGroup path
instead of reading a single dense feature out of Batch.dense_features, so a SID
model can take multiple content embeddings + side-info in one feature group
(FORGE/PLUM motivation).
- BaseSidModel: add init_input/build_input + self.embedding_group (called from
__init__, as in every RankModel-based model); derive _input_dim from
group_total_dim(feature_group); remove _extract_feature.
- proto: drop input_dim (derived); embedding_feature_name -> feature_group
(default "deep"); ClipConfig {clip,is_clip_pair}_feature_name ->
{clip,clip_pair}_feature_group.
- sid_rqvae/sid_rqkmeans: predict via build_input; CLIP dual path reads the
paired + pair-flag groups.
- mock configs: drop input_dim/embedding_feature_name (feature_groups now
load-bearing). Unit tests reworked to real create_features + feature_groups.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- SidRqvae: validate at init that clip_feature_group's total dim equals the main feature_group dim (both share one encoder) — else fail fast instead of an opaque matmul shape error on the first contrastive forward; add a test. - add tzrec/tests/configs/sid_rqvae_clip_mock.config exercising the full CLIP path (deep + clip_image + clip_pair groups, clip_config, sid_clip_loss). - trim verbose/narrative comments to code-focused one-liners. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… CLIP guards Post-E2 review/cleanup of the SID models: - efficiency: update_metric reads predictions["recon_target"] instead of re- running build_input per eval step; SidRqkmeans exposes recon_target alongside x_hat in fitted-eval predictions. - merge: extract SidRqvae._rqvae_pass (encode->quantize->decode), used by _predict_rqvae once and _predict_mixed twice (was triplicated). - bug: CLIP eval mse/rel_loss now respect recon_mask, so they score the same (non-pair) rows the recon loss optimizes instead of all rows. - fail-fast: validate the main feature_group (has_group) and the CLIP groups exist, the paired group matches input_dim, and the pair-flag group is dim-1 — instead of opaque KeyError/matmul errors on the first forward. - docs: correct the update_train_metric + sid_integration_test docstrings that the EmbeddingGroup refactor made stale. - tests: cover the CLIP group guards + metric masking; restore the derived input_dim<1 guard test (via a 0-dim group). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…string Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…dist) Verified on test/sid_abstract (ee3ecdc), re-applied onto the current code: - alibaba#2: encoder/decoder use the framework MLP(hidden_units) + a bare trailing nn.Linear for the unbounded latent/recon projection (MLP always activates its last layer); behavior-preserving vs the removed private _build_mlp. - alibaba#10: VectorQuantizeLayer l2 distance uses torch.cdist(x, codebook, p=2).pow(2); drop the hand-rolled _squared_euclidean_distance helper. cdist matches it to fp32 noise incl. a coincident zero-distance point (finite grad, no NaN). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…home - residual_vector_quantizer: drop the no-op `distance_types = [d]*n_layers` list (distance_type is uniform; pass it directly). - infonce_loss: compute pair_mask.float()/n_valid once in forward() instead of re-deriving them in each of the 3 masked-CE calls. - vector_quantize: fix a stale docstring (the commitment loss moved to tzrec.loss.commitment_loss.CommitmentLoss; the named method no longer exists). - test placement: fold the stale-named SquaredEuclideanDistanceTest into VectorQuantizeTest (it tests the layer, not a removed helper); add sid_model_test.py covering the recon_loss factory + _masked_mean, which were only exercised end-to-end through the SidRqvae/SidRqkmeans subclass tests. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ctor
Regression from review item D5 ("derive apply_ste from forward_mode == STE").
In test/sid_abstract the RVQ residual walk called quantize(apply_ste=False), so
each per-layer step returned the RAW codebook vector, which flows into the
cumulative `latents` and carries gradient to the codebook. D5 dropped that
param and hard-wired the per-layer STE wrap to fire whenever forward_mode==STE,
so inside the (input-detached) walk `quantized = x + (q - x).detach()` detached
the codebook from `latents`. The commitment loss's codebook term (loss2,
||z_e.detach() - latents||^2) then had ZERO gradient to the codebook: it froze
at init, so as the encoder trained the commitment loss grew unbounded (recon
stayed fine — the aggregate STE still trains encoder+decoder around the frozen
codebook). Symptom: commitment_loss 8.7 -> 33 while recon ~0.01.
Fix: VectorQuantizeLayer.quantize() returns the raw codebook vector; the single
straight-through estimator is applied on the aggregate in
ResidualVectorQuantizer.forward (the only production user of the layer). This
matches test/sid_abstract's effective behavior (apply_ste=False for the walk).
Add a regression test asserting the RVQ's `latents` carries codebook gradient,
and update the standalone layer test to the new (raw-vector) contract.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
… comment Test audit (all in residual_vector_quantizer_test.py): - drop test_raises_on_too_few_points — duplicate of the N>=K guard already owned by kmeans_quantize_test.py (the test's own comment admits it comes from the shared faiss_kmeans_fit primitive). - drop test_decode_codes_shared_base — decode_codes is a base ResidualQuantizer method; residual_quantizer_test.py covers it (shape + sum + dtype, stronger). - drop test_get_codes_no_grad — shape-only and mis-named; get_codes shape is covered by the base walk test, and the eval no_grad contract by the retained test_forward_get_codes_consistent_eval. Comment audit: - fix the _quantize_layer comment that claimed the per-layer STE wrap is a "numeric no-op" — the codebook-gradient bugfix removed that wrap, so the layer returns the raw codebook vector (the single STE is on the aggregate). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Move the ~40-line CLIP read + validation block out of __init__ into a private _init_clip() helper, so __init__ reads as a clean sequence (super -> clip -> encoder/decoder/quantizer). Behavior-identical (same attributes set, same fail-fast ValueErrors); the helper flattens the nesting with an early return when CLIP is off, and documents that it must run after super().__init__() (it needs embedding_group / _input_dim). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
master advanced to 1.2.21; bump past the collision to stay ahead. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
Thanks for the careful reviews @tiankongdeguiji — here's a consolidated summary of the changes. The smaller fixes are listed first, then the larger refactors and the one suggestion I'd handle a bit differently. All of these changes have been validated through end-to-end experiments and have no impact on model results. Smaller fixes (done)
Config-driven lossesI added a Input via
|
These describe what the code does NOT do (or used to do) rather than the code: - sid_model.proto: the two "(input_dim is not configured here — ...)" notes on SidRqvae / SidRqkmeans (input_dim is simply not a field). - residual_vector_quantizer.py / types.py: "the commitment loss is no longer computed inside the quantizer" — trimmed to just what `latents` is for. - sid_rqkmeans_test.py: "input_dim is no longer a config knob ... rather than an explicit input_dim=0" history note on the 0-dim-group guard test. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Set _clip_feature_group / _clip_pair_feature_group to None at the top, then assign the real values after the `if not self._use_clip: return` (where CLIP is guaranteed on). Drops the two `if self._use_clip else None` ternaries and never reads cfg.clip_config.* when clip_config is unset. Behavior-identical; the attributes are still always defined and the clip_config<->sid_clip_loss consistency check still runs before the return. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The three learnable contrastive temperatures (logit_scale_self/cl/logit_scale), their init/cap constants and the clamp+exp were on BaseSidModel and passed into the loss through the feats dict. They are loss-internal hyperparameters, so move them into MaskedInfoNCELoss: it declares the nn.Parameters in __init__ and does the clamp(<= ln 100) + exp inside forward. Because the loss is stored in the model's _loss_modules (an nn.ModuleDict), the parameters are still registered, trained and checkpointed exactly as before. BaseSidModel loses three params, the scaled() helper, the two module constants and the numpy/nn imports; the model just hands the loss the four embeds + the pair mask. Behavior-identical (same init, same clamp+exp); the overflow and large-scale tests now reach the temperatures on the loss module. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
/simplify follow-up to the temperature move: the previous commit inlined the clamp(<= ln 100) + exp idiom three times in forward(), dropping the single scaled() helper the pre-refactor sid_model.py had. Restore it as a _scaled() staticmethod so the cap constant + the clamp+exp contract live in one place. Also dedup the overflow test: hoist the three-temperature tuple into `scales` and loop over it for both fill_ and the grad assertions (matching the sibling infonce_loss_test). Behavior-identical. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
| // Name of the main input FEATURE GROUP (built by the model's EmbeddingGroup | ||
| // from ModelConfig.feature_groups). May hold one or many content/side-info | ||
| // features; their concatenated dim is the K-Means dimension. | ||
| optional string feature_group = 40 [default = "deep"]; |
There was a problem hiding this comment.
not needed, auto detect group_names()[0]
Rename the SID losses to a modality-agnostic "contrastive" scheme, register all
three SID losses uniformly in _loss_modules, and tidy feature_group handling.
Behavior-preserving: same losses, gradients and init (the contrastive loss is
bit-identical to the prior version on forward + both grads).
Protos (loss.proto, models/sid_model.proto):
- messages ReconLoss/CommitmentLoss/SidClipLoss -> SidReconLoss/
SidCommitmentLoss/SidContrastiveLoss; oneof field sid_clip_loss ->
contrastive_loss (field numbers 6/7/8 unchanged).
- ClipConfig -> ContrastiveConfig; clip_feature_group -> pair_feature_group;
clip_pair_feature_group -> pair_flag_feature_group; clip_config ->
contrastive_config.
- feature_group: drop the "deep" default -> optional with single-group
auto-detect (DLRM-style); multiple groups must name it explicitly.
Losses:
- new SidReconLoss(_Loss) (mse/l1/cos, reduction="none") replaces the
recon_loss factory fn; recon is now a _loss_modules entry like the rest.
- commitment_loss.py -> sid_commitment_loss.py, CommitmentLoss ->
SidCommitmentLoss(_Loss) (aligned to _Loss for uniformity; behavior-neutral).
- infonce_loss.py -> sid_contrastive_loss.py, MaskedInfoNCELoss ->
SidContrastiveLoss: explicit (embed_a/b, embed_a/b_ori, pair_mask) args +
scalar return (was a feats dict / {"loss": ...}); one batched all-gather
(was two); the six logit/CE blocks DRY'd into a loop; logit_scale ->
logit_scale_ori.
Models:
- sid_model.py: uniform _sid_loss_impl dispatch; _resolve_feature_group
auto-detect; drop the recon factory + _recon_fn.
- sid_rqvae.py: _init_contrastive, attr renames (_use_contrastive /
_pair_feature_group / _pair_flag_feature_group), _predict_mixed locals
(is_pair_raw / pair_mask); drop CLIP/image-text docstring framing.
- sid_rqkmeans.py unchanged (inherits the renamed base).
Tests/configs: colocated test renames + a new sid_recon_loss_test.py; the recon
factory tests move out of sid_model_test.py; sid_rqvae_clip_mock.config ->
sid_rqvae_contrastive_mock.config (new fields + explicit feature_group).
Run gen_proto.sh after this change to regenerate the *_pb2 modules.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…loss/base Comment-refinement pass over the refactor: drop the temperature/L2-normalize restate comments in SidContrastiveLoss.forward and the "shared config fields" section label in BaseSidModel; reword the label-refresh comment to state why (it carries the cross-rank offset). Non-obvious-trick comments (overflow clamp, NaN backstop, neg_fill dtype, operand pairing, fail-fast, trailing-Linear) are kept. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Remove inline shape annotations (# (B, D) etc.) and other comments that merely restate the adjacent line — the wider codebase does not use inline shape comments, and they are derivable. Kept: public docstrings, and comments that explain a non-obvious trick or rationale (the Eq 4.2 Householder note, the trailing-Linear / unbounded-projection rationale, "first training forward only", "residual, in place", "differentiable", the contrastive col-mask purpose and safe-labels fallback, the finfo/overflow and detach-swap notes). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Stricter comment pass: strip comments that narrate control flow / structure or restate the adjacent code (the "default to no contrastive path", "codes-only path", Sinkhorn "Step 1..4" labels, "first training forward only", the gumbel/sinkhorn auto-disable narration whose logged warning already says it, etc.). Condense the trailing-Linear and codebook-freeze notes. Kept only what guards a real mistake: comments whose removal would let a future edit reintroduce a bug (the codebook-freeze "no per-layer STE wrap", the DDP rank-0 faiss-broadcast deadlock ordering, the sinkhorn_epsilon>0 overflow and non-negative-cost requirements), external refs (Eq 4.2 Householder), the non-obvious contrastive operand-pairing table, and public docstrings. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Run ruff format: stripping a trailing shape comment left the rescaled_embeddings expression short enough to fit on one line, so the formatter collapses the parenthesized form. Fixes the RunCodeStyleCI (ruff-format) failure on the prior commit. No behavior change. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…uto-detect) - SidReconLoss now owns the masked-mean reduction: forward(x_hat, x, mask) returns the scalar reconstruction loss, so all three SID loss modules take their operands/mask and return a scalar uniformly. Drop _masked_mean + div_no_nan from sid_model.py; its tests move to sid_recon_loss_test.py (sid_model_test.py held only those, so it is removed). - Drop the SidRqvae/SidRqkmeans `feature_group` proto field and _resolve_feature_group: the main input is just group_names()[0] (the first declared feature group), per the maintainer's "auto detect group_names()[0]". Single-group models need no field; the contrastive path names its paired groups in contrastive_config. Configs/tests drop the explicit feature_group. - Audit fixes in these files: stale "CLIP feature wiring" proto comment -> "contrastive"; rename a leftover `clip` test local to `contrastive`. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
test_phase2_replacement asserted only ``(idx >= cap).any()``, which is near-tautological (it catches replacement being disabled outright but misses a badly-low accept probability). Require the phase-2 survivor count to exceed cap // 2 (the expected count is ~= cap). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Summary
Adds SidRqvae, an end-to-end differentiable RQ-VAE semantic-ID model, on top of the SID foundation already in master (
BaseSidModel/ResidualQuantizerfrom #538,SidRqkmeansfrom #539). It complements the encoder-free, FAISS-trainedSidRqkmeanswith a gradient-trained encoder→quantizer→decoder that learns the codebook jointly with a reconstruction objective, and optionally a pair-contrastive (masked InfoNCE) objective.The model runs on both CPU and GPU (gradient training, no FAISS dependency at train time unless
kmeans_initwarm-start is requested).What's in scope
Config-driven losses. SID losses are configured through
ModelConfig.losses: aLossConfigcarries onesid_lossoneofterm per objective (recon_loss,commitment_loss,contrastive_loss), andBaseSidModelregisters each as a_loss_modulesentry and computes them centrally (mirroringRankModel). Each is a small_Lossmodule that takes its operands/mask and returns a scalar:SidReconLoss(tzrec/loss/sid_recon_loss.py) — per-row reconstruction distance (recon_type∈ {l2,l1,cos}), masked-mean reduced (the mixed recon+contrastive path scores reconstruction rows only).SidCommitmentLoss(tzrec/loss/sid_commitment_loss.py) — VQ-VAE commitment between the encoder output and the per-layer cumulative quantized vectors (latent_weight[w1, w2],commitment_type).SidContrastiveLoss(tzrec/loss/sid_contrastive_loss.py) — masked InfoNCE over mixed reconstruction+pair batches, driven by a per-row pair flag. Modality-agnostic (no image/text assumption), with structural masked-logit fill (dtypefinfo.min, AMP-safe), learnable per-group temperatures (clamped beforeexp), and built-in differentiable all-gather for DDP.Models
SidRqvae(tzrec/models/sid_rqvae.py): configurable-depth encoder/decoder MLPs (mirrored; frameworkMLP+ a trailing bareLinearso the latent/reconstruction stay unbounded), an N-layer residual vector quantizer, and an optional contrastive dual-encoder path. Knobs:forward_mode∈ {ste,gumbel_softmax},normalize_residuals,distance_type,rotation_trick(arXiv:2410.06424), Sinkhorn balanced assignment, and FAISSkmeans_initwarm-start (opt-in, default off).EmbeddingGroup/build_input, so SID models consume the same grouped feature tensor as every other model and support multiple content + side-info embeddings per FORGE/PLUM. The main input is the (sole, or first-declared) feature group — auto-detected, nothing to configure; the contrastive path names its paired + pair-flag groups incontrastive_config.Modules
ResidualVectorQuantizer(modules/sid/residual_vector_quantizer.py): gradient-trained N-layer residual VQ. STE returns the raw codebook vector and applies one aggregate straight-through on the encoder side (the codebook trains via the commitment loss); Gumbel-Softmax's soft assignment is differentiable directly. Includes the FAISS residual warm-start (faiss_residual_kmeans) with a DDP fail-together broadcast that avoids the rank-0-raises-while-others-hang deadlock.VectorQuantizeLayer(modules/sid/vector_quantize.py): single-layerQuantizeLayerwith STE / Gumbel-Softmax / Sinkhorn balanced assignment (Sinkhorn auto-disabled under Gumbel;sinkhorn_epsilonguarded> 0to avoid the exp-kernel overflow).faiss_kmeans_fitprimitive inmodules/sid/kmeans_quantize.py, used by both the RQ-VAE warm-start and the RQ-KMeans offline fit (one-layer FAISS fit with anN >= Kguard).Protos
loss.proto:SidReconLoss/SidCommitmentLoss/SidContrastiveLossmessages + thesid_lossoneofonLossConfig.models/sid_model.proto:SidRqvae,SinkhornConfig, andContrastiveConfigmessages. The main feature group is auto-detected (group_names()[0]), so there is nofeature_groupfield.model.proto:sid_rqvaewired into theModelConfigmodeloneof.Test plan
Colocated unit tests for every new/changed module (distributed paths exercised by 2-rank cases inside the same
*_test.py), plus an end-to-end integration test:sid_rqvae_test.py— model forward/loss/metric, STE + Gumbel, with/without the contrastive path, the fail-fast group/dim guards, and the logit-scale overflow clamp.sid_recon_loss_test.py,sid_commitment_loss_test.py,sid_contrastive_loss_test.py— the three loss modules (per-row distance + masked reduction, commitment directions, masked contrastive incl. empty-mask safety, 2-rank all-gather, AMP-safe fill).residual_vector_quantizer_test.py— Gumbel grad flow vs the STE codebook gradient (trained via commitment), FAISS warm-start, non-uniform codebooks.vector_quantize_test.py— STE / Gumbel / Sinkhorn assignment.residual_quantizer_test.py,kmeans_quantize_test.py,residual_kmeans_quantizer_test.py,quantize_layer_test.py— the shared quantizer primitives.sid_integration_test.py— RQ-VAE and RQ-KMeans end-to-end (train → eval → predict).All CPU tests green via
python -m tzrec.tests.run; GPU/DDP-only cases gate on CUDA. FAISS-dependent tests skip cleanly when faiss is absent.