-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathcheckpoint.py
More file actions
302 lines (251 loc) · 11.3 KB
/
checkpoint.py
File metadata and controls
302 lines (251 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
Checkpoint resharding utilities.
Disk format (canonical):
- No module.{N}. prefix
- Experts individually indexed: layers.1.mlp.experts.3.gate_proj.weight
Runtime format (local):
- DualPipeV prefix: module.0.layers.1.mlp.experts.gate_proj.weight
- Experts stacked per EP rank: shape [experts_per_rank, ...]
"""
import re
from typing import Any, Callable, Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard
__all__ = ["to_canonical_model", "to_canonical_optim", "to_localized_model", "to_localized_optim"]
MODULE_PREFIX_RE = re.compile(r"^module\.\d+\.")
INDEXED_EXPERT_RE = re.compile(r"(.*)\.experts\.(\d+)\.(.*)")
def strip_prefix(key: str) -> str:
"""Strip module.{N}. prefix from a DualPipeV FQN."""
return MODULE_PREFIX_RE.sub("", key)
def find_moe(key: str, named_modules: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""
Return the MoE module for a stacked expert key, or None.
A stacked key looks like module.0.layers.1.mlp.experts.gate_proj.weight
(no numeric index after .experts.). Already-indexed keys like
layers.1.mlp.experts.3.gate_proj.weight return None.
"""
if ".experts." not in key:
return None
moe_path, _, after = key.partition(".experts.")
if after and after.split(".")[0].isdigit():
return None
mod = named_modules.get(moe_path)
if mod and hasattr(mod, "ep_rank") and hasattr(mod, "experts_per_rank"):
return mod
return None
def expert_range(mod: nn.Module) -> Tuple[int, int]:
"""Global expert index range [start, end) for this EP rank."""
start = mod.ep_rank * mod.experts_per_rank
return start, start + mod.experts_per_rank
def unwrap_dtensor_experts(value: Any, expected_n: int) -> Optional[Tuple[Any, int, int]]:
"""
Extract local expert data from a DTensor without triggering all_gather.
When FSDP shards the stacked expert tensor along dim 0 via Shard(0),
each DP rank holds a contiguous subset of experts. This helper extracts
that local subset and computes the global expert offset so that unpack
can emit per-expert keys for only the experts this rank actually owns,
with zero GPU communication.
Works for both model tensors (a single DTensor) and optimizer state
entries (a dict whose tensor values are DTensors).
Returns (localized_value, local_expert_count, dp_expert_offset)
or None if value is not a sharded DTensor.
"""
def _info(dt: Any) -> Optional[Tuple[torch.Tensor, int, int]]:
"""Return (local_tensor, local_n, dp_offset) for a Shard(0) DTensor, or None."""
if not isinstance(dt, DTensor) or dt.dim() == 0 or dt.shape[0] != expected_n:
return None
if not dt.placements:
return None
if not isinstance(dt.placements[0], Shard):
return None
if dt.placements[0].dim != 0:
return None
local = dt._local_tensor
if local.shape[0] >= expected_n:
return None
dp_rank = dt.device_mesh.get_local_rank()
dp_size = dt.device_mesh.size()
chunk, remainder = divmod(expected_n, dp_size)
dp_offset = dp_rank * chunk + min(dp_rank, remainder)
return local, local.shape[0], dp_offset
if isinstance(value, DTensor):
return _info(value)
if isinstance(value, dict):
ref = None
for v in value.values():
ref = _info(v) if isinstance(v, DTensor) else None
if ref is not None:
break
if ref is None:
return None
_, local_n, dp_offset = ref
localized = {k: v._local_tensor if isinstance(v, DTensor) else v for k, v in value.items()}
return localized, local_n, dp_offset
return None
def unpack(
entries: Dict[str, Any],
named_modules: Dict[str, nn.Module],
unstack: Callable[[Any, int, int], Any],
) -> Dict[str, Any]:
"""
Strip module prefix and unpack stacked experts into individual entries.
unstack(value, num_local_experts, local_idx) extracts one expert slice
from a stacked value. For model tensors this is v[i]; for optimizer
state dicts it slices each sub-tensor.
When values are FSDP-sharded DTensors the expert dimension is extracted
locally (via unwrap_dtensor_experts) so that each DP rank emits keys
only for the experts it owns -- no GPU all_gather is triggered.
"""
result: Dict[str, Any] = {}
for key, value in entries.items():
canon = strip_prefix(key)
moe = find_moe(key, named_modules)
if moe is None:
result[canon] = value
continue
start, _ = expert_range(moe)
n = moe.experts_per_rank
local_info = unwrap_dtensor_experts(value, n)
if local_info is not None:
local_value, local_n, dp_offset = local_info
for i in range(local_n):
global_idx = start + dp_offset + i
ckey = canon.replace(".experts.", ".experts.%d." % global_idx, 1)
result[ckey] = unstack(local_value, local_n, i)
else:
for i in range(n):
ckey = canon.replace(".experts.", ".experts.%d." % (start + i), 1)
result[ckey] = unstack(value, n, i)
return result
def repack(
entries: Dict[str, Any],
fqn_map: Dict[str, str],
named_modules: Dict[str, nn.Module],
restack: Callable[[Dict[int, Any]], Any],
) -> Dict[str, Any]:
"""
Remap canonical FQNs to local FQNs and repack individual experts into stacked format.
fqn_map maps {canonical_fqn: local_fqn}.
restack(by_global_idx) stacks individual expert values back into one.
"""
result: Dict[str, Any] = {}
to_stack: Dict[str, Dict[int, Any]] = {}
for canon_fqn, value in entries.items():
m = INDEXED_EXPERT_RE.match(canon_fqn)
if m:
prefix, idx_str, suffix = m.groups()
stacked_canon = "%s.experts.%s" % (prefix, suffix)
local = fqn_map.get(stacked_canon)
if local is not None:
moe_path = local.partition(".experts.")[0]
moe = named_modules.get(moe_path)
if moe and hasattr(moe, "ep_rank"):
s, e = expert_range(moe)
idx = int(idx_str)
if s <= idx < e:
to_stack.setdefault(local, {})[idx] = value
else:
local = fqn_map.get(canon_fqn)
if local is not None:
result[local] = value
for local, by_idx in to_stack.items():
result[local] = restack(by_idx)
return result
def unstack_optim(entry: Dict[str, Any], n: int, i: int) -> Dict[str, Any]:
"""Extract one expert slice from a stacked optimizer state entry."""
return {
k: v[i] if isinstance(v, torch.Tensor) and v.dim() > 0 and v.shape[0] == n else v
for k, v in entry.items()
}
def restack_tensors(by_idx: Dict[int, torch.Tensor]) -> torch.Tensor:
"""Stack individual expert tensors back into one."""
return torch.stack([v for _, v in sorted(by_idx.items())])
def restack_optim(by_idx: Dict[int, Dict[str, Any]]) -> Dict[str, Any]:
"""Stack individual expert optimizer state entries back into one."""
items = sorted(by_idx.items())
sample = items[0][1]
return {
k: torch.stack([by_idx[i][k] for i, _ in items])
if isinstance(sample[k], torch.Tensor) and sample[k].dim() > 0
else sample[k]
for k in sample
}
def to_canonical_model(
state_dict: Dict[str, torch.Tensor], model: nn.Module
) -> Dict[str, torch.Tensor]:
"""Canonicalize model state: strip module prefix, unstack experts."""
return unpack(state_dict, dict(model.named_modules()), lambda v, n, i: v[i])
def to_canonical_optim(optim_state: Dict, model: nn.Module) -> Dict:
"""Canonicalize optimizer state: strip module prefix, unstack expert states."""
state = unpack(optim_state["state"], dict(model.named_modules()), unstack_optim)
params = list(state.keys())
param_groups = []
for g in optim_state["param_groups"]:
group = {k: v for k, v in g.items() if k != "params"}
group["params"] = params
param_groups.append(group)
return {"state": state, "param_groups": param_groups}
def rewrap_dtensor_experts(result: Dict[str, Any], model: nn.Module) -> None:
"""
Re-wrap restacked plain tensors as DTensors to match model parameters.
When unpack extracts local FSDP shards via unwrap_dtensor_experts,
the individual expert tensors are plain (non-DTensor). After repack
stacks them the result is a plain tensor whose shape equals the FSDP local
shard (e.g. [16, ...] when the full stacked parameter is [32, ...]).
model.load_state_dict compares against the DTensor global shape, so
we must wrap these plain tensors back into DTensors with the same mesh and
placements as the model parameter. For optimizer state entries (dicts of
tensors), every tensor with dim() > 0 is wrapped.
"""
param_dtensors = {n: p for n, p in model.named_parameters() if isinstance(p, DTensor)}
for name, param in param_dtensors.items():
value = result.get(name)
if value is None:
continue
if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
result[name] = DTensor.from_local(
value,
device_mesh=param.device_mesh,
placements=param.placements,
run_check=False,
)
elif isinstance(value, dict):
for k, v in value.items():
if isinstance(v, torch.Tensor) and not isinstance(v, DTensor) and v.dim() > 0:
value[k] = DTensor.from_local(
v,
device_mesh=param.device_mesh,
placements=param.placements,
run_check=False,
)
def to_localized_model(
canonical: Dict[str, torch.Tensor], model: nn.Module
) -> Dict[str, torch.Tensor]:
"""Localize model state: remap FQNs to this rank's local keys, restack experts."""
named_modules = dict(model.named_modules())
model_keys = set(model.state_dict().keys())
fqn_map = {strip_prefix(k): k for k in model_keys}
result = repack(canonical, fqn_map, named_modules, restack_tensors)
rewrap_dtensor_experts(result, model)
return result
def to_localized_optim(optim_state: Dict, model: nn.Module) -> Dict:
"""
Localize optimizer state: remap FQNs, restack experts, rebuild param_groups.
The param_groups are rebuilt with ALL of the current model's local FQNs
(ignoring the loaded params lists) because DCP deduplicates non-tensor
metadata across PP ranks. Hyperparameters (lr, betas, ...) are taken from
the loaded param_groups.
"""
named_modules = dict(model.named_modules())
fqn_map = {strip_prefix(n): n for n, _ in model.named_parameters()}
state = repack(optim_state["state"], fqn_map, named_modules, restack_optim)
rewrap_dtensor_experts(state, model)
all_local = list(fqn_map.values())
param_groups = []
for g in optim_state["param_groups"]:
group = {k: v for k, v in g.items() if k != "params"}
group["params"] = all_local
param_groups.append(group)
return {"state": state, "param_groups": param_groups}