Skip to content

Commit 8b12dcd

Browse files
committed
fix: use correct type for _current_global_parameter_metas
1 parent 490c222 commit 8b12dcd

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

checkpoint_engine/ps.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,12 @@ def gather_metas(self, checkpoint_name: str):
725725
if not self._global_device_uuids:
726726
global_device_uuids.append(metas_buckets.device_uuid)
727727
if metas_buckets.memory_buffer_metas_list:
728-
self._current_global_parameter_metas[i] = metas_buckets
728+
# _current_global_parameter_metas value should be MemoryBufferMetaList, but metas_buckets is DataToGather
729+
# so we need to convert it to MemoryBufferMetaList
730+
self._current_global_parameter_metas[i] = MemoryBufferMetaList(
731+
memory_buffer_metas_list=metas_buckets.memory_buffer_metas_list,
732+
p2p_store_addr=metas_buckets.p2p_store_addr,
733+
)
729734
num_parameters += sum(len(x.metas) for x in metas_buckets.memory_buffer_metas_list)
730735
if not self._all_hosts:
731736
self._all_hosts = all_hosts

0 commit comments

Comments
 (0)