Skip to content

Commit 6ce0ad1

Browse files
committed
refine local_steps update logic
1 parent adb12e8 commit 6ce0ad1

4 files changed

Lines changed: 9 additions & 18 deletions

File tree

xtuner/v1/datasets/dataloader.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -57,10 +57,6 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict:
5757
dataloader_state = get_dataloader_state(self, consumed_samples)
5858
return cast(dict, dataloader_state)
5959

60-
def record_consumed_samples(self, n: int) -> None:
61-
if hasattr(self.sampler, "record_consumed_samples"):
62-
self.sampler.record_consumed_samples(n)
63-
6460
def get_total_consumed_samples(self) -> int:
6561
sampler = self.sampler
6662
if hasattr(sampler, "get_total_consumed_steps"):

xtuner/v1/datasets/preset_sampler.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,9 @@ def __init__(
163163
def __iter__(self) -> Iterator[int]:
164164
# load order from npy → global_order → rank_view 类型均为 memmap, 子视图 的路径仍然保持
165165
# memmap 语义(视图、按需分页、文件后端);单机多进程可共享同一份文件页缓存
166-
yield from self.global_order[self.step + self.rank : self.total_size : self.world_size]
166+
for idx in self.global_order[self.step + self.rank : self.total_size : self.world_size]:
167+
self._consumed.record(1)
168+
yield int(idx)
167169
self.step = 0
168170

169171
def __len__(self) -> int:
@@ -172,9 +174,6 @@ def __len__(self) -> int:
172174
def set_epoch(self, epoch: int) -> None:
173175
self.epoch = epoch
174176

175-
def record_consumed_samples(self, n: int) -> None:
176-
self._consumed.record(n)
177-
178177
def get_total_consumed_steps(self) -> int:
179178
return self._consumed.total_for_checkpoint()
180179

xtuner/v1/datasets/sampler.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,9 @@ def __iter__(self) -> Iterator[int]:
113113
# subsample
114114
indices = indices[self.step + self.rank : self.total_size : self.world_size]
115115

116-
yield from iter(indices)
116+
for idx in indices:
117+
self._consumed.record(1)
118+
yield idx
117119
self.step = 0
118120

119121
def __len__(self) -> int:
@@ -133,9 +135,6 @@ def set_epoch(self, epoch: int) -> None:
133135
"""
134136
self.epoch = epoch
135137

136-
def record_consumed_samples(self, n: int) -> None:
137-
self._consumed.record(n)
138-
139138
def get_total_consumed_steps(self) -> int:
140139
return self._consumed.total_for_checkpoint()
141140

@@ -275,7 +274,9 @@ def __iter__(self) -> Iterator[int]:
275274
assert len(indices) == self.total_size
276275
indices = indices[self.step + self.rank : self.total_size : self.world_size]
277276
assert len(indices) == self.num_samples - self.step // self.world_size
278-
yield from iter(indices)
277+
for idx in indices:
278+
self._consumed.record(1)
279+
yield idx
279280
self.step = 0
280281

281282
def __len__(self) -> int:
@@ -294,9 +295,6 @@ def set_epoch(self, epoch: int) -> None:
294295
"""
295296
self.epoch = epoch
296297

297-
def record_consumed_samples(self, n: int) -> None:
298-
self._consumed.record(n)
299-
300298
def get_total_consumed_steps(self) -> int:
301299
return self._consumed.total_for_checkpoint()
302300

xtuner/v1/train/trainer.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -727,7 +727,6 @@ def fit(self):
727727
train_begin = time.time()
728728
time_before_get_data = time.time()
729729
for data_batch in self._data_iter():
730-
consumed_samples = len(data_batch)
731730
time_before_train_step = time.time()
732731

733732
ProberList.set_step(self._cur_step + 1)
@@ -762,7 +761,6 @@ def fit(self):
762761
self._cur_step += 1
763762
step_tokens = train_step_info["step_consumed_tokens"]
764763
self._local_total_consumed_tokens += step_tokens
765-
self._dataloader.record_consumed_samples(consumed_samples)
766764
self._train_time = time_after_train_step - train_begin
767765

768766
# Compute training metrics

0 commit comments

Comments
 (0)