[feat] Kafka: event-time driven checkpointing from message timestamp #541

Merged
tiankongdeguiji merged 12 commits into alibaba:master from tiankongdeguiji:feat/kafka-event-time-checkpoint
Jun 9, 2026

tiankongdeguiji (Collaborator) commented Jun 6, 2026

Motivation

For online/streaming training from Kafka, KafkaReader previously surfaced only the message offset (as the resume cursor); the Kafka message timestamp was never read, so there was no way to checkpoint based on the event-time of the consumed data. This PR surfaces the message timestamp and uses it to drive checkpoint saves on consumed event-time, which is what you want for time-aligned snapshots of a stream.

What it does

Surfaces the Kafka message timestamp on the batch and adds event-time–driven checkpoint triggers, alongside the existing step/epoch triggers:

  • batch.data_timestamp — the max consumed event-time in the batch, as Unix-epoch seconds (float; -1.0 when the source has no timestamps). KafkaReader normalizes its ms timestamp to seconds; BaseDataset._build_batch just surfaces the per-batch max (unit-agnostic, so non-Kafka datasets are unaffected).
  • New TrainConfig fields:
    • save_checkpoints_timestamp_interval — save every N seconds of consumed event-time, aligned to the Unix epoch (wall-clock), not training epochs. 0 disables.
    • save_checkpoints_timestamps — absolute event-time targets (Unix-epoch seconds); save once when consumed data crosses each. Empty disables.
    • save_checkpoints_timestamp_quorum — fraction of workers (0, 1], default 0.5, that must have consumed past a boundary/target before a save fires.
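
The two event-time triggers can be illustrated with a minimal sketch. The function names `crossed_interval` and `crossed_target` are hypothetical (the PR's own helper is `should_save_on_timestamp`, whose signature is not shown here), and the handling of the first timestamped batch is an assumption based on the "first-batch init" case listed under Tests.

```python
import math
from typing import Sequence


def crossed_interval(prev_ts: float, cur_ts: float, interval_s: float) -> bool:
    """True when cur_ts falls in a later Unix-epoch-aligned bucket than prev_ts.

    Hypothetical helper; buckets are multiples of interval_s since the Unix
    epoch, so they are unrelated to training epochs.
    """
    if interval_s <= 0 or cur_ts < 0:
        return False  # trigger disabled, or no event-time available (-1.0)
    if prev_ts < 0:
        return False  # assumption: the first timestamped batch only initialises
    return math.floor(cur_ts / interval_s) > math.floor(prev_ts / interval_s)


def crossed_target(prev_ts: float, cur_ts: float, targets: Sequence[float]) -> bool:
    """True when some absolute target t satisfies prev_ts < t <= cur_ts."""
    if prev_ts < 0 or cur_ts < 0:
        return False
    return any(prev_ts < t <= cur_ts for t in targets)


# With a 3600 s interval, 7100 -> 7250 crosses the boundary at 7200.
hourly = crossed_interval(7100.0, 7250.0, 3600.0)
same_bucket = crossed_interval(7250.0, 7300.0, 3600.0)
```

Here `hourly` is true and `same_bucket` is false, because 7250 and 7300 both lie in the bucket that starts at 7200.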

Design

  • Centralized save policy. All save cadence (step / epoch / event-time) + the single per-step dedupe + the event-time watermark live in CheckpointManager.maybe_save(...) -> bool; the training loop calls it at the step / epoch / final points and runs eval when it returns True. This removes the four hand-guarded save sites in the loop (the source of an epoch-vs-timestamp double-save) and makes the decision logic unit-testable.
  • Worker quorum (not min/max). Each rank passes its local event-time to maybe_save, which all-gathers across workers and reduces to the (1 - quorum) upper quantile — the event-time that ≥ quorum fraction of workers have reached. This generalizes min (quorum=1.0) / max (quorum→0), is robust to a single outlier/garbage timestamp at the default 0.5, and counts a worker without a timestamp (-1.0) as "not past" so the quorum is over all workers.
  • Off the GPU. The per-step reconciliation gathers over a dedicated CPU (gloo) process group that CheckpointManager creates in set_save_policy and owns, so it never runs on the model device or forces a per-step device sync that would stall the TrainPipelineSparseDist prefetch.
  • Resume-correct. The event-time watermark is persisted in dataloader_state (saved every checkpoint, restored on continue_train) and advanced on every save, so resume re-fires no already-saved boundary and misses none. It's a single global float, identical across ranks (the per-key max-merge is a no-op on it) and ignored by the existing dataloader_state consumers. dataloader_state is typed Dict[str, Any] since it now holds the float watermark alongside the int offsets.
  • Distributed-safe. The reconciliation is gated only by config (identical on all ranks) and the save decision is deterministic + identical across ranks, so the collective save is entered in lockstep; nothing collective runs when not saving. It relies on the same equal-step-count invariant the model's own gradient collectives already require.
  • No behavior change when unset — with the new config fields unset, training is identical to before.
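
As a rough illustration of the worker-quorum reduction described above, the sketch below picks the event-time that at least a `quorum` fraction of workers have reached. The name `quorum_event_time_sketch` and the rounding rule (`ceil(quorum * n)` workers) are assumptions, not the PR's `quorum_event_time` implementation.

```python
import math
from typing import Sequence


def quorum_event_time_sketch(worker_ts: Sequence[float], quorum: float = 0.5) -> float:
    """Return the largest event-time reached by >= quorum fraction of workers.

    Hypothetical sketch. A worker without a timestamp reports -1.0, which
    sorts below any real time and therefore counts as "not past".
    """
    if not 0.0 < quorum <= 1.0:
        raise ValueError(f"quorum must be in (0, 1], got {quorum}")
    if not worker_ts:
        return -1.0
    ranked = sorted(worker_ts, reverse=True)
    # Smallest worker count satisfying the quorum; epsilon guards float noise.
    need = max(1, math.ceil(quorum * len(ranked) - 1e-9))
    return ranked[need - 1]


times = [100.0, 101.0, 102.0, 1e12]  # one far-future outlier
at_default = quorum_event_time_sketch(times)        # 102.0
as_min = quorum_event_time_sketch(times, 1.0)       # 100.0
as_max = quorum_event_time_sketch(times, 0.01)      # 1e12
```

With four workers and the default 0.5, two workers must have reached the returned time, so the single outlier is ignored. A negative result (for example when `quorum=1.0` and one worker reports -1.0) would mean "no event-time this step".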

Tests

  • Pure helpers quorum_event_time (quantile incl. outlier/sentinel cases) and should_save_on_timestamp (interval/target crossings, no-refire, first-batch init) — parameterized.
  • CheckpointManager.maybe_save — step/epoch/timestamp triggers, the dedupe (incl. the epoch-after-same-step regression), watermark stamping, restore seeding, and the -1.0-sentinel path; plus _reconcile_event_time single-process behavior (save mocked, no GPU/dist needed).
  • _build_batch data_timestamp extraction; Kafka roundtrip asserts a positive seconds value.
  • Docs: the three new TrainConfig fields documented in train.md.

🤖 Generated with Claude Code

Surface the Kafka message timestamp on the batch (Batch.data_timestamp,
ms) and use it to trigger checkpoint saves based on consumed event-time:

- save_checkpoints_timestamp_interval: save every N seconds of consumed
  event-time, aligned to epoch boundaries so cut points are deterministic
- save_checkpoints_timestamps: save once when consumed data crosses each
  configured absolute timestamp
- checkpoint_timestamp_reduce: min/max cross-rank reconciliation of the
  per-rank consumed event-time

The KafkaReader now reads msg.timestamp() alongside offset/partition and
carries it in a transient __data_timestamp__ column that _build_batch
collapses into Batch.data_timestamp. The timestamp is not persisted:
dataloader_state.json stays {source: offset}, so resume/restore and the
non-kafka (file/odps) datasets are unchanged.

The per-step cross-rank all-reduce that reconciles event-time is gated
only by the config-derived ts_trigger_active, so every rank participates
identically and the collective stays deadlock-safe; ranks without a
timestamp contribute a neutral sentinel.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Jun 8, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 8, 2026
Comment thread tzrec/main.py Outdated
Comment on lines +522 to +527
if dist.is_initialized():
    ts_tensor = torch.tensor(
        [local_ts_s], dtype=torch.float64, device=ts_device
    )
    dist.all_reduce(ts_tensor, op=reduce_op)
    global_ts_s = ts_tensor.item()

When timestamp-checkpointing is active, this runs a synchronous all_reduce + .item() every training step. The .item() forces a host/GPU sync that isn't overlapped with compute, and unlike the gradient all-reduce it can't piggyback on the backward pass — so it partially defeats TrainPipelineSparseDist's prefetch/overlap. Checkpoint decisions only need coarse event-time resolution; consider reducing every N steps (with N derived identically on all ranks so the collective stays in lockstep) or using async_op=True and reading the result one step later.

Separately, this per-step collective assumes all ranks run the same number of steps. That invariant is already required by the model's fwd/bwd collectives, but it's worth noting that with Kafka's uneven partition data and check_all_workers_data_status=False (i.e. no batch_cost_size), a rank whose partitions drain first hits StopIteration and breaks before this all_reduce, hanging the others. The "every rank participates identically" comment holds only under that equal-step-count invariant — worth documenting that this feature needs batch_cost_size/even data, or gating the all-reduce on check_all_workers_data_status.
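
The "reduce every N steps" suggestion can be sketched as a gate that depends only on the step counter, so every rank reaches the same decision without communicating. `should_reconcile` and `every_n_steps` are hypothetical names used for illustration.

```python
def should_reconcile(i_step: int, every_n_steps: int) -> bool:
    """Decide whether this step runs the cross-rank event-time collective.

    Hypothetical helper. It is a pure function of i_step and a config value,
    so all ranks enter (or skip) the collective together.
    """
    if every_n_steps <= 1:
        return True  # degenerate case: reconcile on every step
    return i_step % every_n_steps == 0


steps_with_collective = [s for s in range(7) if should_reconcile(s, 3)]
```

For `every_n_steps=3` this selects steps 0, 3 and 6. The lockstep property still relies on all ranks running the same number of steps.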

Comment thread tzrec/main.py Outdated
Comment on lines +554 to +556
  if save_checkpoints_epochs > 0 and i_step > 0:
      if (i_epoch + 1) % save_checkpoints_epochs == 0:
-         last_ckpt_step = i_step
-         ckpt_manager.save(i_step, model, optimizer, dataloader_state)
-         if eval_dataloader is not None:
-             _evaluate(
-                 model,
-                 eval_dataloader,
-                 eval_config,
-                 eval_result_filename=eval_result_filename,
-                 global_step=i_step,
-                 eval_summary_writer=eval_summary_writer,
-                 global_epoch=i_epoch,
-                 check_all_workers_data_status=check_all_workers_data_status,
-             )
-             model.train()
+         do_checkpoint(i_step)

The step-trigger and timestamp-trigger both guard against re-saving a step already saved (i_step != last_ckpt_step), and so does the final save below — but the epoch-trigger doesn't. If a timestamp-triggered save fired on the last step of the epoch (so last_ckpt_step == i_step), this re-saves the same step: a redundant full checkpoint write and a redundant eval.

Suggested change
- if save_checkpoints_epochs > 0 and i_step > 0:
-     if (i_epoch + 1) % save_checkpoints_epochs == 0:
-         last_ckpt_step = i_step
-         ckpt_manager.save(i_step, model, optimizer, dataloader_state)
-         if eval_dataloader is not None:
-             _evaluate(
-                 model,
-                 eval_dataloader,
-                 eval_config,
-                 eval_result_filename=eval_result_filename,
-                 global_step=i_step,
-                 eval_summary_writer=eval_summary_writer,
-                 global_epoch=i_epoch,
-                 check_all_workers_data_status=check_all_workers_data_status,
-             )
-             model.train()
-         do_checkpoint(i_step)
+ if save_checkpoints_epochs > 0 and i_step > 0 and i_step != last_ckpt_step:
+     if (i_epoch + 1) % save_checkpoints_epochs == 0:
+         do_checkpoint(i_step)

Alternatively, centralize the guard inside do_checkpoint (early-return when step == last_ckpt_step) so no call site can forget it.
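
A minimal sketch of that alternative, with the guard living inside the save helper so call sites cannot forget it. `DedupedSaver` and `save_fn` are hypothetical names; the real `do_checkpoint` also runs evaluation, which is omitted here.

```python
from typing import Callable, Optional


class DedupedSaver:
    """Hypothetical wrapper: remembers the last saved step and skips repeats."""

    def __init__(self, save_fn: Callable[[int], None]) -> None:
        self._save_fn = save_fn
        self._last_ckpt_step: Optional[int] = None

    def do_checkpoint(self, step: int) -> bool:
        """Save once per step; return True only when a save actually happened."""
        if step == self._last_ckpt_step:
            return False  # another trigger already saved this step
        self._save_fn(step)
        self._last_ckpt_step = step
        return True


saved_steps = []
saver = DedupedSaver(saved_steps.append)
first = saver.do_checkpoint(10)   # timestamp trigger fires on step 10
second = saver.do_checkpoint(10)  # epoch trigger on the same step is skipped
```

Returning a bool lets the caller run evaluation only when a checkpoint was really written.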

Comment thread tzrec/protos/train.proto Outdated
Comment on lines +75 to +77
// event-time interval (seconds) for saving checkpoint, based on the consumed
// data timestamp (e.g. kafka message timestamp). saves once per epoch-aligned
// interval boundary that the consumed event-time crosses. 0 disables.

"epoch-aligned" is misleading here — alongside save_checkpoints_epochs/num_epochs it reads as "aligned to a training epoch", but the code is floor(ts / interval_s), i.e. boundaries are multiples of interval_s since the Unix epoch (wall-clock), unrelated to training epochs. Same wording appears in checkpoint_util.py (should_save_on_timestamp docstring/comments) and train.md ("按epoch对齐") — worth fixing in all three.

Suggested change
- // event-time interval (seconds) for saving checkpoint, based on the consumed
- // data timestamp (e.g. kafka message timestamp). saves once per epoch-aligned
- // interval boundary that the consumed event-time crosses. 0 disables.
+ // event-time interval (seconds) for saving checkpoint, based on the consumed
+ // data timestamp (e.g. kafka message timestamp). saves once per interval
+ // boundary (aligned to the Unix epoch, not training epochs) crossed. 0 disables.

github-actions Bot commented Jun 8, 2026

Review summary

Clean, well-documented feature. The do_checkpoint(step) de-duplication of the four save sites is a real readability win, should_save_on_timestamp is pure and well unit-tested, and the transient __data_timestamp__ column is correctly popped before feature parsing so it can never leak into features or checkpoint state. When the feature is unset, the training path is unchanged. Nice work.

I left inline comments on three points; the substantive ones:

  • Per-step all_reduce + .item() (main.py): when active, a synchronous collective + host sync runs every step and isn't overlapped with compute. Consider reducing every N steps (N identical across ranks) or async_op. Also worth documenting the equal-step-count invariant for uneven Kafka partitions (batch_cost_size/check_all_workers_data_status).
  • Epoch-trigger double-save (main.py): the epoch save lacks the i_step != last_ckpt_step guard the other sites have, so it can re-save (checkpoint + eval) a step a timestamp-trigger already saved.
  • "epoch-aligned" wording (train.proto / checkpoint_util.py / train.md): means Unix-epoch-aligned, not training-epoch-aligned — easy to misread.

A couple of additional notes not worth inlining:

  • Test gaps (no Kafka needed): the _build_batch data_timestamp extraction (None when column absent, the >= 0 filter that drops -1 "no timestamp" batches, max selection) and Batch.to()/pin_memory() propagation of data_timestamp are both untested and both fail silently (event-time checkpointing becomes a no-op) rather than loudly. Both are unit-testable via the existing _TestReader/_TestDataset harness. The checkpoint_timestamp_reduce validation (ValueError on bad value) is also a cheap test to add.
  • Untrusted timestamps: dataset.py only rejects negative values. A single far-future/garbage msg.timestamp() advances last_ckpt_data_ts_s and can then silently suppress all future saves (mitigated under the default min reduce, exposed under max). A sanity bound + warning would harden the streaming path.

…uorum

Address PR review on event-time checkpointing.

Centralize save-cadence in CheckpointManager.maybe_save():
- step / epoch / event-time triggers + a single dedupe live in one place, so
  no save site can re-save a step another already saved (fixes the epoch
  double-save at its root, and makes the decision logic unit-testable).
- the event-time watermark gets its natural home alongside the manager's
  dataloader_state IO.

Replace the min/max cross-rank reduce with a worker quorum:
- save_checkpoints_timestamp_quorum (float, default 0.5): save once >= quorum
  fraction of data-carrying workers have consumed past the boundary/target,
  computed as the (1-quorum) upper quantile of per-worker event-times.
- generalizes the old knobs (quorum 1.0 == min, near 0 == max) and is robust to
  a single outlier/garbage timestamp at the default, so no wall-clock guard is
  needed.

Make resume correct: persist the event-time watermark in dataloader_state and
seed it on continue_train, so resume re-fires no already-saved boundary and
misses none. The cross-rank reconciliation uses all_gather (not all_reduce) so a
NaN sentinel for a rank without a timestamp never poisons the others; the loop
does the gather, the manager does the quorum + decision (no collective when not
saving).

Clarify wording: the interval is aligned to the Unix epoch (wall-clock), not to
training epochs.

Add unit tests for quorum_event_time, maybe_save (incl. the epoch-dedupe
regression), Batch.data_timestamp extraction in _build_batch, and
Batch.to()/pin_memory() propagation.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
tiankongdeguiji (Collaborator, Author) commented

Thanks for the review — addressed in 2dd3dcc. Summary of how each point was handled:

Epoch-trigger double-save (real bug) — fixed at the root. Rather than adding the missing guard to one more site, the save-cadence logic (step + epoch + event-time) is now centralized in CheckpointManager.maybe_save(...), which owns last_ckpt_step and does a single dedupe. No call site can re-save a step another already saved. This also makes the decision logic unit-testable (it wasn't, buried in the loop) — added a regression test test_maybe_save_dedupe_epoch_after_same_step.

Per-step all_reduce + .item(). The cross-rank reconciliation is now an all_gather driving a worker quorum (replacing the min/max knob): save_checkpoints_timestamp_quorum (float, default 0.5) — save once ≥ quorum fraction of data-carrying workers have consumed past the boundary/target, computed as the (1-quorum) upper quantile. It generalizes the old knobs (1.0≡min, →0≡max) and is robust to a single far-future/garbage timestamp at the default, so the "untrusted timestamp" concern needs no separate wall-clock guard. (Kept per-step per design discussion; the loop does the gather, the manager does the quorum + decision with no collective when not saving.)

"epoch-aligned" wording. Fixed in train.proto, checkpoint_util.py, and train.md to clarify it means the Unix epoch (wall-clock), not training epochs.

Test gaps. Added: quorum_event_time (quantile edge cases incl. outlier robustness), maybe_save (step/epoch/timestamp/final + dedupe + watermark + restore seeding), _build_batch data_timestamp extraction (max valid / -1 filter / absent→None), and Batch.to()/pin_memory() propagation (the silent-failure guard).

"Uneven partitions → deadlock". Not a new deadlock: a rank that drains first hangs at the model's own gradient all-reduce inside pipeline.progress, which precedes our collective — the equal-step-count invariant is already required by the existing collectives. Corrected the overclaiming comment to say so.

New: resume correctness. The event-time watermark is now persisted in dataloader_state and seeded on continue_train, so resume re-fires no already-saved boundary and misses none (advanced on every save). It's a single global float, identical across ranks, and ignored by every existing dataloader_state consumer.
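
The resume behaviour can be sketched as a watermark stored next to the offsets and read back on restore. This is an illustration under assumptions: the string value of `DATA_TS_WATERMARK`, the helper names, and the `"topic:0"` offset keys are hypothetical.

```python
import json
import math
from typing import Any, Dict

DATA_TS_WATERMARK = "__data_ts_watermark__"  # hypothetical key value


def stamp_watermark(dataloader_state: Dict[str, Any], data_ts: float) -> None:
    """Record the event-time of the save alongside the resume offsets."""
    dataloader_state[DATA_TS_WATERMARK] = data_ts


def restore_watermark(dataloader_state: Dict[str, Any]) -> float:
    """Seed the watermark on resume; -1.0 means no save has happened yet."""
    return float(dataloader_state.get(DATA_TS_WATERMARK, -1.0))


state: Dict[str, Any] = {"topic:0": 41, "topic:1": 97}  # illustrative offsets
stamp_watermark(state, 7250.0)
restored = json.loads(json.dumps(state))  # simulate the save/restore round trip
watermark = restore_watermark(restored)

interval_s = 3600.0
refires = math.floor(7300.0 / interval_s) > math.floor(watermark / interval_s)
next_fires = math.floor(10900.0 / interval_s) > math.floor(watermark / interval_s)
```

After resuming with a watermark of 7250, a batch at 7300 stays in the same bucket and does not re-fire, while a batch at 10900 crosses the next boundary and fires.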

tiankongdeguiji and others added 5 commits June 8, 2026 15:02
Make the event-time feature single-unit. Batch.data_timestamp is now
Unix-epoch seconds (float), matching the config (interval/targets in seconds),
the persisted watermark, the trigger math, and Python time.time(). Normalization
from the raw kafka ms happens once in _build_batch (max_ts / 1000.0); the
internal __data_timestamp__ arrow column keeps the raw ms. The all_gather no
longer needs a /1000 conversion. float64 preserves full ms precision at
epoch-seconds magnitude.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…d_batch

The ms->seconds conversion is Kafka-specific, so move it out of the generic
BaseDataset._build_batch and into KafkaReader where the ms timestamp originates.
The __data_timestamp__ column is now float64 Unix-epoch seconds (-1 sentinel for
unavailable); _build_batch just surfaces the per-batch max valid value and stays
unit-agnostic (the producing reader owns the unit).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
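
A small sketch of the split described in this commit: the reader converts milliseconds to seconds, and the generic batch builder only takes a per-batch max. `kafka_ms_to_seconds`, `batch_data_timestamp` and `NO_TIMESTAMP` are hypothetical names, and treating any negative millisecond value as "unavailable" is an assumption.

```python
from typing import Iterable

NO_TIMESTAMP = -1.0  # sentinel for "source has no timestamp"


def kafka_ms_to_seconds(ts_ms: int) -> float:
    """Reader side: normalise a millisecond timestamp to Unix-epoch seconds."""
    return ts_ms / 1000.0 if ts_ms >= 0 else NO_TIMESTAMP


def batch_data_timestamp(row_ts_s: Iterable[float]) -> float:
    """Dataset side: unit-agnostic per-batch max; the sentinel sorts lowest."""
    return max(row_ts_s, default=NO_TIMESTAMP)


rows = [kafka_ms_to_seconds(ms) for ms in (1780000000123, -1, 1780000000999)]
batch_ts = batch_data_timestamp(rows)
```

Because -1.0 sorts below any real time, a batch mixing stamped and unstamped rows still reports the latest real event-time, and a batch with no stamped rows reports -1.0.
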
Collapse the repetitive single-assertion tests into @parameterized.expand rows
(param("name", **kwargs)): should_save_on_timestamp, quorum_event_time, the
_build_batch data_timestamp extraction, and Batch copy (to/pin_memory)
propagation. The stateful maybe_save sequence tests stay as individual cases.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Make data_timestamp a plain float defaulting to -1.0 ("unavailable") rather
than Optional[float]/None. The -1.0 sentinel is now uniform end-to-end: kafka's
own -1 -> __data_timestamp__ column -> Batch.data_timestamp -> filtered by the
>= 0 check in the gather. This drops the None/NaN special-casing in
_gather_worker_event_times (now a simple >= 0 filter) and removes the unused
math import in main.py.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ments

- Extract the per-worker event-time all-gather into a generic
  dist_util.gather_float_scalar(value, device); it no longer filters the -1.0
  sentinel. The downstream quorum/save logic handles -1.0: a worker without a
  timestamp sorts low (counts as "not past"), so the quorum is over all workers,
  and maybe_save treats a negative quorum result as "no event-time this step".
- Add Args/Returns docstrings to the public CheckpointManager methods
  (maybe_save, set_save_policy) and gather_float_scalar; tighten the inline
  comments added in this PR.
- Drop the now-redundant >= 0 check in _build_batch (the -1.0 sentinel sorts
  below any real time, so max already yields the right value).
- Remove the Batch.to()/pin_memory() propagation test.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label Jun 8, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label Jun 9, 2026
Comment thread tzrec/main.py Outdated
Comment on lines +493 to +497
worker_ts = (
    gather_float_scalar(batch.data_timestamp, ts_device)
    if need_ts
    else None
)

When timestamp checkpointing is enabled, this runs every step: gather_float_scalar builds a CUDA float64 tensor, does an all_gather_into_tensor, then .tolist() — a blocking D2H sync. The baseline loop has no per-step host sync (train metrics only .item() every log_step_count_steps), so this new sync caps how far TrainPipelineSparseDist can run ahead to prefetch/input_dist the next batch and can cost throughput — exactly in the streaming-Kafka scenario where this feature is on.

Since saves are rare relative to steps, consider gathering/deciding only every N steps (keep it a deterministic function of i_step so it stays in lockstep across ranks), or running the gather on a gloo/CPU PG so the .tolist() doesn't synchronize the compute stream. No impact when the feature is disabled.

Comment on lines +348 to +351
if (ts_interval_s > 0 or len(ts_targets) > 0) and not (0.0 < ts_quorum <= 1.0):
    raise ValueError(
        f"save_checkpoints_timestamp_quorum must be in (0, 1], got {ts_quorum}."
    )

This quorum-range validation (and its ValueError) has no test, and gather_float_scalar — the new distributed primitive that feeds the trigger — is entirely untested (there's no dist_util_test.py). Both are cheap and high-value:

  • assertRaises(ValueError) for quorum<=0 / quorum>1 with a trigger active, plus a no-raise case when both triggers are off (the conjunction here is exactly what silently breaks on refactor).
  • At least the non-distributed gather_float_scalar path (return [value]); the gloo mp.spawn pattern already in checkpoint_util_test.py can cover the distributed rank-order path.
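
The first suggestion might look roughly like the test below. `validate_quorum` is a hypothetical standalone stand-in that mirrors the condition quoted above; the real check lives inside `set_save_policy`, whose full signature is not shown in this thread.

```python
import unittest
from typing import Sequence


def validate_quorum(interval_s: float, targets: Sequence[float], quorum: float) -> None:
    """Hypothetical stand-in: reject a bad quorum only when a trigger is active."""
    trigger_active = interval_s > 0 or len(targets) > 0
    if trigger_active and not (0.0 < quorum <= 1.0):
        raise ValueError(
            f"save_checkpoints_timestamp_quorum must be in (0, 1], got {quorum}."
        )


class QuorumValidationTest(unittest.TestCase):
    def test_rejects_out_of_range_when_trigger_active(self) -> None:
        for bad in (0.0, -0.1, 1.5):
            with self.assertRaises(ValueError):
                validate_quorum(60.0, [], bad)

    def test_ignored_when_both_triggers_off(self) -> None:
        validate_quorum(0.0, [], 0.0)  # must not raise

    def test_accepts_upper_bound_with_target_trigger(self) -> None:
        validate_quorum(0.0, [1780000000.0], 1.0)  # must not raise
```

The second case pins down the conjunction: with both triggers off, an out-of-range quorum is deliberately ignored.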

# advance + persist the watermark on every save so resume is exact
self._last_data_ts = data_ts
if dataloader_state is not None:
    dataloader_state[DATA_TS_WATERMARK] = data_ts

Minor: dataloader_state is annotated Dict[str, int] throughout (save, maybe_save, save_dataloader_state, update_dataloder_state), but this stores a float watermark. Functionally harmless — the cross-rank max-merge in save_dataloader_state is a no-op since the value is rank-identical — but in this pyre-checked codebase the annotation is now inaccurate. Consider widening to Dict[str, Union[int, float]].

Comment on lines +66 to +68
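- save_checkpoints_timestamp_interval: interval, in seconds of data time (e.g. the timestamp of Kafka messages), at which to save the model. A save fires each time the event-time of the consumed data crosses an interval boundary; boundaries are aligned to Unix time (epoch/wall-clock), not to training epochs. Defaults to 0, which disables it. Only takes effect for data sources that carry a timestamp (e.g. KafkaDataset)
- save_checkpoints_timestamps: list of absolute points in data time (Unix timestamps in seconds) at which to save the model; a save fires once each time the event-time of the consumed data crosses one of them. Defaults to empty, which disables it
- save_checkpoints_timestamp_quorum: in distributed training each rank consumes different partitions, so event-times differ across ranks; a save fires only once at least this fraction of ranks (range (0,1\], default 0.5) has consumed past the boundary/target. 1.0 saves only when every rank has crossed (most conservative); smaller values are more aggressive (at the minimum, a single rank suffices); the default is robust to a single anomalous or far-ahead timestamp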

The save_checkpoints_steps/epochs bullets above note that an evaluation is run after the model is saved (保存模型后会做一次评估), but these three timestamp bullets don't. Timestamp-triggered saves also run an eval (run_eval fires whenever maybe_save returns True), so worth adding the same note for parity.

github-actions Bot commented Jun 9, 2026

Review summary

Reviewed across code quality, performance, tests, docs, and distributed safety. This is a well-built PR — the centralization of all save cadence into CheckpointManager.maybe_save(...) is a real improvement (it removes the three near-identical save+eval blocks and the epoch-vs-timestamp double-save), the quorum quantile is a nice robustness choice over min/max, and the resume path (persisted watermark seeded in restore_dataloader_state, untouched by set_save_policy) is correct. I verified the key distributed-safety claim: with timestamp checkpointing on, the per-step gather + the maybe_save decision are a deterministic function of identical all-gathered inputs on every rank, and TrainPipelineSparseDist._next_batch drops the remainder on all workers together, so the collective save stays in lockstep — no deadlock from uneven/exhausted data. Defaults disable every trigger, so existing runs are unaffected.

Noteworthy items, posted inline:

  • Perf (main.py, only when enabled): the per-step gather_float_scalar adds a per-step .tolist() D2H sync the baseline loop doesn't have, which can throttle the prefetch pipeline in the streaming case this targets. Suggest gathering/deciding every N steps or on a CPU PG.
  • Tests: the set_save_policy quorum ValueError path and gather_float_scalar (both paths) are untested — cheap, valuable additions.
  • Types: dataloader_state is Dict[str, int] but now carries a float watermark (harmless, but inaccurate under pyre).
  • Docs: the timestamp bullets in train.md should mention that a save also runs an eval, for parity with the step/epoch bullets.

One optional hardening (not blocking): the watermark read from dataloader_state.json on resume is fed straight into int(... // interval_s) without a type/finiteness check. The framework only ever writes valid floats, so this only bites on a hand-edited/corrupted state file (a loud crash at the first resumed step), but a float(...) + math.isfinite guard in restore_dataloader_state would make resume robust to it.
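
A sketch of that guard, assuming a hypothetical helper `parse_watermark` called from `restore_dataloader_state`. Falling back to the -1.0 sentinel instead of raising is a design choice made for this example, not something the thread specifies.

```python
import math
from typing import Any


def parse_watermark(raw: Any, default: float = -1.0) -> float:
    """Coerce a persisted watermark to a finite float, else return default.

    Hypothetical helper. A hand-edited or corrupted state file may hold a
    string, null, NaN or Infinity instead of a valid float.
    """
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return default
    return value if math.isfinite(value) else default


ok = parse_watermark(7250)
from_text = parse_watermark("not-a-number")
from_nan = parse_watermark(float("nan"))
```

Returning the sentinel makes a corrupted watermark behave like a fresh start, at the cost of possibly re-firing one boundary; raising a descriptive error would be the stricter alternative.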

Nothing here blocks merge.

tiankongdeguiji and others added 5 commits June 9, 2026 11:12
…e type

- gather_float_scalar now all-gathers via all_gather_object over a passed
  process group instead of an all_gather_into_tensor + .tolist() on the model
  device. The train loop creates a dedicated gloo group for it, so the per-step
  event-time reconciliation stays on CPU and never forces a GPU device sync that
  would stall the TrainPipelineSparseDist prefetch in the streaming case.
- Widen the dataloader_state annotations to Dict[str, Any] (save / maybe_save /
  save_dataloader_state / update_dataloder_state / restore / load_state_dict),
  since it now holds the float event-time watermark alongside the int offsets;
  the old Dict[str, int] was inaccurate under pyre. checkpoint_info stays
  Dict[str, int] (offsets only).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Remove the standalone dist_util.gather_float_scalar. maybe_save now takes this
rank's data_timestamp and reconciles it across workers internally
(_reconcile_event_time: all_gather_object + quorum + sentinel handling).
set_save_policy creates the dedicated CPU (gloo) group once and the manager owns
it (self._ts_group), so the training loop no longer creates or threads a process
group through the save calls -- it just passes its local data_timestamp.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
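
The CPU-side reconciliation described in the last two commit messages can be approximated as below. This is a sketch, not the manager's `_reconcile_event_time`: `gather_event_times` is a hypothetical name, and it assumes the caller created the group once with `torch.distributed.new_group(backend="gloo")`. Without PyTorch or an initialized process group it falls back to a single-worker list.

```python
from typing import Any, List, Optional


def gather_event_times(local_ts: float, group: Optional[Any] = None) -> List[float]:
    """All-gather one float per rank over the given (ideally gloo) group.

    Hypothetical sketch. all_gather_object pickles Python objects, so on a
    gloo group the exchange stays on CPU and does not touch the model device.
    """
    try:
        import torch.distributed as dist
    except ImportError:
        return [local_ts]  # PyTorch not installed: single-process fallback
    if not (dist.is_available() and dist.is_initialized()):
        return [local_ts]  # not running distributed
    gathered: List[Optional[float]] = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(gathered, local_ts, group=group)
    return [float(v) for v in gathered]  # rank order, -1.0 sentinels kept


single_process = gather_event_times(1780000000.5)
```

The gathered list keeps -1.0 sentinels in place, so the quorum step that follows can count those workers as "not past".
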
@tiankongdeguiji tiankongdeguiji merged commit 7886e4c into alibaba:master Jun 9, 2026
7 checks passed
WhiteSwan1 added a commit to WhiteSwan1/TorchEasyRec that referenced this pull request Jun 9, 2026
Resolve main.py tail-checkpoint conflict against alibaba#541's CheckpointManager
refactor: the final save now goes through ckpt_manager.maybe_save(final=True),
and our on_train_end "is_ckpt_after_train" need is threaded in via a new
`force=` arg on maybe_save (bypasses the per-step dedupe so the FAISS codebook
fitted in on_train_end is persisted even when the last in-loop save hit the
final step). Also pulls in alibaba#541 (Kafka event-time checkpointing). Bump 1.2.19.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
WhiteSwan1 added a commit to WhiteSwan1/TorchEasyRec that referenced this pull request Jun 11, 2026
…_abstract

Brings the reviewed alibaba#539 foundation onto feat/sid_abstract (which already
carries alibaba#538 + an older RQ-VAE/RQ-Kmeans port), and syncs to upstream/master
(alibaba#540, alibaba#541, which alibaba#539 already contains).

Conflict resolutions:
- sid_rqkmeans.py(+test), residual_kmeans_quantizer.py, sid_model.py:
  take alibaba#539's canonical versions (BaseSidModel now hosts both SID models,
  with mse/rel_loss/unique_sid_ratio and the unified x_hat recon key).
- types.py: union — keep alibaba#539's QuantizeOutput, retain feat's
  QuantizeForwardMode enum + ResidualQuantizerOutput (RQ-VAE needs them).
- protos/models/sid_model.proto: union — alibaba#539's typed FaissKmeansConfig +
  clean SidRqkmeans, re-add feat's SinkhornConfig/ClipConfig/SidRqvae;
  drop the now-unused struct.proto import.
- protos/model.proto: enable `SidRqvae sid_rqvae = 600;` (the field alibaba#539
  reserved for this follow-up).
- main.py / model.py on_train_end: take alibaba#539's wording; drop feat's forced
  tail-checkpoint (SID models rely on the final=True tail save).

Transitional state: old modules/sid/kmeans.py still coexists with alibaba#539's
kmeans_quantize.py, and the RQ-VAE stack is still on the old abstraction —
both retired in the follow-up refactor commit. All SID modules import.
3 participants