[feat] Kafka: event-time driven checkpointing from message timestamp#541
Conversation
Surface the Kafka message timestamp on the batch (Batch.data_timestamp,
ms) and use it to trigger checkpoint saves based on consumed event-time:
- save_checkpoints_timestamp_interval: save every N seconds of consumed
event-time, aligned to epoch boundaries so cut points are deterministic
- save_checkpoints_timestamps: save once when consumed data crosses each
configured absolute timestamp
- checkpoint_timestamp_reduce: min/max cross-rank reconciliation of the
per-rank consumed event-time
The KafkaReader now reads msg.timestamp() alongside offset/partition and
carries it in a transient __data_timestamp__ column that _build_batch
collapses into Batch.data_timestamp. The timestamp is not persisted:
dataloader_state.json stays {source: offset}, so resume/restore and the
non-kafka (file/odps) datasets are unchanged.
The per-step cross-rank all-reduce that reconciles event-time is gated
only by the config-derived ts_trigger_active, so every rank participates
identically and the collective stays deadlock-safe; ranks without a
timestamp contribute a neutral sentinel.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
| if dist.is_initialized(): | ||
| ts_tensor = torch.tensor( | ||
| [local_ts_s], dtype=torch.float64, device=ts_device | ||
| ) | ||
| dist.all_reduce(ts_tensor, op=reduce_op) | ||
| global_ts_s = ts_tensor.item() |
There was a problem hiding this comment.
When timestamp-checkpointing is active, this runs a synchronous all_reduce + .item() every training step. The .item() forces a host/GPU sync that isn't overlapped with compute, and unlike the gradient all-reduce it can't piggyback on the backward pass — so it partially defeats TrainPipelineSparseDist's prefetch/overlap. Checkpoint decisions only need coarse event-time resolution; consider reducing every N steps (with N derived identically on all ranks so the collective stays in lockstep) or using async_op=True and reading the result one step later.
Separately, this per-step collective assumes all ranks run the same number of steps. That invariant is already required by the model's fwd/bwd collectives, but it's worth noting that with Kafka's uneven partition data and check_all_workers_data_status=False (i.e. no batch_cost_size), a rank whose partitions drain first hits StopIteration and breaks before this all_reduce, hanging the others. The "every rank participates identically" comment holds only under that equal-step-count invariant — worth documenting that this feature needs batch_cost_size/even data, or gating the all-reduce on check_all_workers_data_status.
| if save_checkpoints_epochs > 0 and i_step > 0: | ||
| if (i_epoch + 1) % save_checkpoints_epochs == 0: | ||
| last_ckpt_step = i_step | ||
| ckpt_manager.save(i_step, model, optimizer, dataloader_state) | ||
| if eval_dataloader is not None: | ||
| _evaluate( | ||
| model, | ||
| eval_dataloader, | ||
| eval_config, | ||
| eval_result_filename=eval_result_filename, | ||
| global_step=i_step, | ||
| eval_summary_writer=eval_summary_writer, | ||
| global_epoch=i_epoch, | ||
| check_all_workers_data_status=check_all_workers_data_status, | ||
| ) | ||
| model.train() | ||
| do_checkpoint(i_step) |
There was a problem hiding this comment.
The step-trigger and timestamp-trigger both guard against re-saving a step already saved (i_step != last_ckpt_step), and so does the final save below — but the epoch-trigger doesn't. If a timestamp-triggered save fired on the last step of the epoch (so last_ckpt_step == i_step), this re-saves the same step: a redundant full checkpoint write and a redundant eval.
| if save_checkpoints_epochs > 0 and i_step > 0: | |
| if (i_epoch + 1) % save_checkpoints_epochs == 0: | |
| last_ckpt_step = i_step | |
| ckpt_manager.save(i_step, model, optimizer, dataloader_state) | |
| if eval_dataloader is not None: | |
| _evaluate( | |
| model, | |
| eval_dataloader, | |
| eval_config, | |
| eval_result_filename=eval_result_filename, | |
| global_step=i_step, | |
| eval_summary_writer=eval_summary_writer, | |
| global_epoch=i_epoch, | |
| check_all_workers_data_status=check_all_workers_data_status, | |
| ) | |
| model.train() | |
| do_checkpoint(i_step) | |
| if save_checkpoints_epochs > 0 and i_step > 0 and i_step != last_ckpt_step: | |
| if (i_epoch + 1) % save_checkpoints_epochs == 0: | |
| do_checkpoint(i_step) |
Alternatively, centralize the guard inside do_checkpoint (early-return when step == last_ckpt_step) so no call site can forget it.
| // event-time interval (seconds) for saving checkpoint, based on the consumed | ||
| // data timestamp (e.g. kafka message timestamp). saves once per epoch-aligned | ||
| // interval boundary that the consumed event-time crosses. 0 disables. |
There was a problem hiding this comment.
"epoch-aligned" is misleading here — alongside save_checkpoints_epochs/num_epochs it reads as "aligned to a training epoch", but the code is floor(ts / interval_s), i.e. boundaries are multiples of interval_s since the Unix epoch (wall-clock), unrelated to training epochs. Same wording appears in checkpoint_util.py (should_save_on_timestamp docstring/comments) and train.md ("按epoch对齐") — worth fixing in all three.
| // event-time interval (seconds) for saving checkpoint, based on the consumed | |
| // data timestamp (e.g. kafka message timestamp). saves once per epoch-aligned | |
| // interval boundary that the consumed event-time crosses. 0 disables. | |
| // event-time interval (seconds) for saving checkpoint, based on the consumed | |
| // data timestamp (e.g. kafka message timestamp). saves once per interval | |
| // boundary (aligned to the Unix epoch, not training epochs) crossed. 0 disables. |
Review summaryClean, well-documented feature. The I left inline comments on three points; the substantive ones:
A couple of additional notes not worth inlining:
|
…uorum Address PR review on event-time checkpointing. Centralize save-cadence in CheckpointManager.maybe_save(): - step / epoch / event-time triggers + a single dedupe live in one place, so no save site can re-save a step another already saved (fixes the epoch double-save at its root, and makes the decision logic unit-testable). - the event-time watermark gets its natural home alongside the manager's dataloader_state IO. Replace the min/max cross-rank reduce with a worker quorum: - save_checkpoints_timestamp_quorum (float, default 0.5): save once >= quorum fraction of data-carrying workers have consumed past the boundary/target, computed as the (1-quorum) upper quantile of per-worker event-times. - generalizes the old knobs (quorum 1.0 == min, near 0 == max) and is robust to a single outlier/garbage timestamp at the default, so no wall-clock guard is needed. Make resume correct: persist the event-time watermark in dataloader_state and seed it on continue_train, so resume re-fires no already-saved boundary and misses none. The cross-rank reconciliation uses all_gather (not all_reduce) so a NaN sentinel for a rank without a timestamp never poisons the others; the loop does the gather, the manager does the quorum + decision (no collective when not saving). Clarify wording: the interval is aligned to the Unix epoch (wall-clock), not to training epochs. Add unit tests for quorum_event_time, maybe_save (incl. the epoch-dedupe regression), Batch.data_timestamp extraction in _build_batch, and Batch.to()/pin_memory() propagation. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
Thanks for the review — addressed in Epoch-trigger double-save (real bug) — fixed at the root. Rather than adding the missing guard to one more site, the save-cadence logic (step + epoch + event-time) is now centralized in Per-step "epoch-aligned" wording. Fixed in Test gaps. Added: "Uneven partitions → deadlock". Not a new deadlock: a rank that drains first hangs at the model's own gradient all-reduce inside New: resume correctness. The event-time watermark is now persisted in |
Make the event-time feature single-unit. Batch.data_timestamp is now Unix-epoch seconds (float), matching the config (interval/targets in seconds), the persisted watermark, the trigger math, and Python time.time(). Normalization from the raw kafka ms happens once in _build_batch (max_ts / 1000.0); the internal __data_timestamp__ arrow column keeps the raw ms. The all_gather no longer needs a /1000 conversion. float64 preserves full ms precision at epoch-seconds magnitude. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…d_batch The ms->seconds conversion is Kafka-specific, so move it out of the generic BaseDataset._build_batch and into KafkaReader where the ms timestamp originates. The __data_timestamp__ column is now float64 Unix-epoch seconds (-1 sentinel for unavailable); _build_batch just surfaces the per-batch max valid value and stays unit-agnostic (the producing reader owns the unit). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Collapse the repetitive single-assertion tests into @parameterized.expand rows
(param("name", **kwargs)): should_save_on_timestamp, quorum_event_time, the
_build_batch data_timestamp extraction, and Batch copy (to/pin_memory)
propagation. The stateful maybe_save sequence tests stay as individual cases.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Make data_timestamp a plain float defaulting to -1.0 ("unavailable") rather
than Optional[float]/None. The -1.0 sentinel is now uniform end-to-end: kafka's
own -1 -> __data_timestamp__ column -> Batch.data_timestamp -> filtered by the
>= 0 check in the gather. This drops the None/NaN special-casing in
_gather_worker_event_times (now a simple >= 0 filter) and removes the unused
math import in main.py.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ments - Extract the per-worker event-time all-gather into a generic dist_util.gather_float_scalar(value, device); it no longer filters the -1.0 sentinel. The downstream quorum/save logic handles -1.0: a worker without a timestamp sorts low (counts as "not past"), so the quorum is over all workers, and maybe_save treats a negative quorum result as "no event-time this step". - Add Args/Returns docstrings to the public CheckpointManager methods (maybe_save, set_save_policy) and gather_float_scalar; tighten the inline comments added in this PR. - Drop the now-redundant >= 0 check in _build_batch (the -1.0 sentinel sorts below any real time, so max already yields the right value). - Remove the Batch.to()/pin_memory() propagation test. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
| worker_ts = ( | ||
| gather_float_scalar(batch.data_timestamp, ts_device) | ||
| if need_ts | ||
| else None | ||
| ) |
There was a problem hiding this comment.
When timestamp checkpointing is enabled, this runs every step: gather_float_scalar builds a CUDA float64 tensor, does an all_gather_into_tensor, then .tolist() — a blocking D2H sync. The baseline loop has no per-step host sync (train metrics only .item() every log_step_count_steps), so this new sync caps how far TrainPipelineSparseDist can run ahead to prefetch/input_dist the next batch and can cost throughput — exactly in the streaming-Kafka scenario where this feature is on.
Since saves are rare relative to steps, consider gathering/deciding only every N steps (keep it a deterministic function of i_step so it stays in lockstep across ranks), or running the gather on a gloo/CPU PG so the .tolist() doesn't synchronize the compute stream. No impact when the feature is disabled.
| if (ts_interval_s > 0 or len(ts_targets) > 0) and not (0.0 < ts_quorum <= 1.0): | ||
| raise ValueError( | ||
| f"save_checkpoints_timestamp_quorum must be in (0, 1], got {ts_quorum}." | ||
| ) |
There was a problem hiding this comment.
This quorum-range validation (and its ValueError) has no test, and gather_float_scalar — the new distributed primitive that feeds the trigger — is entirely untested (there's no dist_util_test.py). Both are cheap and high-value:
assertRaises(ValueError)forquorum<=0/quorum>1with a trigger active, plus a no-raise case when both triggers are off (the conjunction here is exactly what silently breaks on refactor).- At least the non-distributed
gather_float_scalarpath (return [value]); the gloomp.spawnpattern already incheckpoint_util_test.pycan cover the distributed rank-order path.
| # advance + persist the watermark on every save so resume is exact | ||
| self._last_data_ts = data_ts | ||
| if dataloader_state is not None: | ||
| dataloader_state[DATA_TS_WATERMARK] = data_ts |
There was a problem hiding this comment.
Minor: dataloader_state is annotated Dict[str, int] throughout (save, maybe_save, save_dataloader_state, update_dataloder_state), but this stores a float watermark. Functionally harmless — the cross-rank max-merge in save_dataloader_state is a no-op since the value is rank-identical — but in this pyre-checked codebase the annotation is now inaccurate. Consider widening to Dict[str, Union[int, float]].
| - save_checkpoints_timestamp_interval: 按数据时间(如Kafka消息的timestamp)保存模型的间隔秒数,每当已消费数据的事件时间跨过一个间隔边界时保存一次;边界按 Unix 时间(epoch/墙钟)对齐,而非训练 epoch。默认0表示关闭。仅对带timestamp的数据源(如KafkaDataset)生效 | ||
| - save_checkpoints_timestamps: 按数据时间保存模型的绝对时间点列表(单位为秒的 Unix 时间戳),已消费数据的事件时间每跨过一个时间点保存一次,默认为空表示关闭 | ||
| - save_checkpoints_timestamp_quorum: 分布式训练下各rank消费的分区不同,事件时间也不同;当至少该比例(取值(0,1\],默认0.5)的rank已消费越过边界/时间点时才触发保存。1.0表示所有rank都越过才保存(最保守),越小越激进(最小可仅需一个rank);默认值对单个异常/超前的timestamp具有鲁棒性 |
There was a problem hiding this comment.
The save_checkpoints_steps/epochs bullets above note 保存模型后会做一次评估, but these three timestamp bullets don't. Timestamp-triggered saves also run an eval (run_eval fires whenever maybe_save returns True), so worth adding the same note for parity.
Review summaryReviewed across code quality, performance, tests, docs, and distributed safety. This is a well-built PR — the centralization of all save cadence into Noteworthy items, posted inline:
One optional hardening (not blocking): the watermark read from Nothing here blocks merge. |
…e type - gather_float_scalar now all-gathers via all_gather_object over a passed process group instead of an all_gather_into_tensor + .tolist() on the model device. The train loop creates a dedicated gloo group for it, so the per-step event-time reconciliation stays on CPU and never forces a GPU device sync that would stall the TrainPipelineSparseDist prefetch in the streaming case. - Widen the dataloader_state annotations to Dict[str, Any] (save / maybe_save / save_dataloader_state / update_dataloder_state / restore / load_state_dict), since it now holds the float event-time watermark alongside the int offsets; the old Dict[str, int] was inaccurate under pyre. checkpoint_info stays Dict[str, int] (offsets only). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Remove the standalone dist_util.gather_float_scalar. maybe_save now takes this rank's data_timestamp and reconciles it across workers internally (_reconcile_event_time: all_gather_object + quorum + sentinel handling). set_save_policy creates the dedicated CPU (gloo) group once and the manager owns it (self._ts_group), so the training loop no longer creates or threads a process group through the save calls -- it just passes its local data_timestamp. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Resolve main.py tail-checkpoint conflict against alibaba#541's CheckpointManager refactor: the final save now goes through ckpt_manager.maybe_save(final=True), and our on_train_end "is_ckpt_after_train" need is threaded in via a new `force=` arg on maybe_save (bypasses the per-step dedupe so the FAISS codebook fitted in on_train_end is persisted even when the last in-loop save hit the final step). Also pulls in alibaba#541 (Kafka event-time checkpointing). Bump 1.2.19. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…_abstract Brings the reviewed alibaba#539 foundation onto feat/sid_abstract (which already carries alibaba#538 + an older RQ-VAE/RQ-Kmeans port), and syncs to upstream/master (alibaba#540, alibaba#541, which alibaba#539 already contains). Conflict resolutions: - sid_rqkmeans.py(+test), residual_kmeans_quantizer.py, sid_model.py: take alibaba#539's canonical versions (BaseSidModel now hosts both SID models, with mse/rel_loss/unique_sid_ratio and the unified x_hat recon key). - types.py: union — keep alibaba#539's QuantizeOutput, retain feat's QuantizeForwardMode enum + ResidualQuantizerOutput (RQ-VAE needs them). - protos/models/sid_model.proto: union — alibaba#539's typed FaissKmeansConfig + clean SidRqkmeans, re-add feat's SinkhornConfig/ClipConfig/SidRqvae; drop the now-unused struct.proto import. - protos/model.proto: enable `SidRqvae sid_rqvae = 600;` (the field alibaba#539 reserved for this follow-up). - main.py / model.py on_train_end: take alibaba#539's wording; drop feat's forced tail-checkpoint (SID models rely on the final=True tail save). Transitional state: old modules/sid/kmeans.py still coexists with alibaba#539's kmeans_quantize.py, and the RQ-VAE stack is still on the old abstraction — both retired in the follow-up refactor commit. All SID modules import.
Motivation
For online/streaming training from Kafka,
KafkaReaderpreviously surfaced only the message offset (as the resume cursor); the Kafka message timestamp was never read, so there was no way to checkpoint based on the event-time of the consumed data. This PR surfaces the message timestamp and uses it to drive checkpoint saves on consumed event-time, which is what you want for time-aligned snapshots of a stream.What it does
Surfaces the Kafka message timestamp on the batch and adds event-time–driven checkpoint triggers, alongside the existing step/epoch triggers:
batch.data_timestamp— the max consumed event-time in the batch, as Unix-epoch seconds (float;-1.0when the source has no timestamps).KafkaReadernormalizes its ms timestamp to seconds;BaseDataset._build_batchjust surfaces the per-batch max (unit-agnostic, so non-Kafka datasets are unaffected).TrainConfigfields:save_checkpoints_timestamp_interval— save every N seconds of consumed event-time, aligned to the Unix epoch (wall-clock), not training epochs.0disables.save_checkpoints_timestamps— absolute event-time targets (Unix-epoch seconds); save once when consumed data crosses each. Empty disables.save_checkpoints_timestamp_quorum— fraction of workers(0, 1], default0.5, that must have consumed past a boundary/target before a save fires.Design
CheckpointManager.maybe_save(...) -> bool; the training loop calls it at the step / epoch / final points and runs eval when it returnsTrue. This removes the four hand-guarded save sites in the loop (the source of an epoch-vs-timestamp double-save) and makes the decision logic unit-testable.maybe_save, which all-gathers across workers and reduces to the(1 - quorum)upper quantile — the event-time that ≥quorumfraction of workers have reached. This generalizes min (quorum=1.0) / max (quorum→0), is robust to a single outlier/garbage timestamp at the default0.5, and counts a worker without a timestamp (-1.0) as "not past" so the quorum is over all workers.CheckpointManagercreates inset_save_policyand owns, so it never runs on the model device or forces a per-step device sync that would stall theTrainPipelineSparseDistprefetch.dataloader_state(saved every checkpoint, restored oncontinue_train) and advanced on every save, so resume re-fires no already-saved boundary and misses none. It's a single global float, identical across ranks (the per-key max-merge is a no-op on it) and ignored by the existingdataloader_stateconsumers.dataloader_stateis typedDict[str, Any]since it now holds the float watermark alongside the int offsets.saveis entered in lockstep; nothing collective runs when not saving. It relies on the same equal-step-count invariant the model's own gradient collectives already require.Tests
quorum_event_time(quantile incl. outlier/sentinel cases) andshould_save_on_timestamp(interval/target crossings, no-refire, first-batch init) — parameterized.CheckpointManager.maybe_save— step/epoch/timestamp triggers, the dedupe (incl. the epoch-after-same-step regression), watermark stamping, restore seeding, and the-1.0-sentinel path; plus_reconcile_event_timesingle-process behavior (savemocked, no GPU/dist needed)._build_batchdata_timestampextraction; Kafka roundtrip asserts a positive seconds value.TrainConfigfields documented intrain.md.🤖 Generated with Claude Code