Skip to content

Commit c590028

Browse files
committed
Sync modeling_mixtral_te.py fix to models/ source and register copy mapping
- Apply FSDP2 DTensor fix to bionemo-recipes/models/mixtral/modeling_mixtral_te.py (source) - Add mixtral modeling file to check_copied_files SOURCE_TO_DESTINATION_MAP - Recipe file now gets copied-file banner via check_copied_files --fix Signed-off-by: svc-bionemo <267129667+svc-bionemo@users.noreply.github.com>
1 parent 23f7df6 commit c590028

2 files changed

Lines changed: 24 additions & 5 deletions

File tree

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,22 @@ def _restack_from_views(self) -> None:
279279
device = torch.cuda.current_device()
280280
for attr_name in ("experts_gate_up_weight", "experts_down_weight"):
281281
old_param = getattr(self, attr_name)
282-
new_data = torch.empty_like(old_param, device=device)
283-
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
284-
setattr(self, attr_name, nn.Parameter(new_data))
282+
if isinstance(old_param.data, DTensor):
283+
# FSDP2 has sharded this param; materialize the local shard on CUDA
284+
# and reconstruct the DTensor wrapper so FSDP2 can manage it.
285+
local_data = old_param.data.to_local()
286+
new_local = torch.empty(local_data.shape, dtype=local_data.dtype, device=device)
287+
torch.nn.init.normal_(new_local, mean=0.0, std=self.initializer_range)
288+
new_dtensor = DTensor.from_local(
289+
new_local,
290+
device_mesh=old_param.data.device_mesh,
291+
placements=old_param.data.placements,
292+
)
293+
setattr(self, attr_name, nn.Parameter(new_dtensor))
294+
else:
295+
new_data = torch.empty_like(old_param, device=device)
296+
torch.nn.init.normal_(new_data, mean=0.0, std=self.initializer_range)
297+
setattr(self, attr_name, nn.Parameter(new_data))
285298

286299
# Re-sync views to point to the new stacked parameter
287300
self._sync_expert_views()
@@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None:
298311
gate_up_w = self.experts_gate_up_weight
299312
if isinstance(gate_up_w, DTensor):
300313
gate_up_w = gate_up_w.to_local()
301-
for i in range(self.num_local_experts):
314+
num_local = gate_up_w.shape[0]
315+
for i in range(num_local):
302316
object.__setattr__(self.experts_gate_up, f"weight{i}", gate_up_w[i])
303317

304318
down_w = self.experts_down_weight
305319
if isinstance(down_w, DTensor):
306320
down_w = down_w.to_local()
307-
for i in range(self.num_local_experts):
321+
num_local_down = down_w.shape[0]
322+
for i in range(num_local_down):
308323
object.__setattr__(self.experts_down, f"weight{i}", down_w[i])
309324

310325
def set_ep_group(self, ep_group: dist.ProcessGroup, ep_mesh: DeviceMesh) -> None:

ci/scripts/check_copied_files.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ def _compare_file_contents(source_file: Path, dest_file: Path, source_display: s
205205
"bionemo-recipes/models/codonfm/modeling_codonfm_te.py": [
206206
"bionemo-recipes/recipes/codonfm_native_te/modeling_codonfm_te.py",
207207
],
208+
# Mixtral TE model -> recipe sync
209+
"bionemo-recipes/models/mixtral/modeling_mixtral_te.py": [
210+
"bionemo-recipes/recipes/mixtral_native_te/modeling_mixtral_te.py",
211+
],
208212
# Common test library - synced between models
209213
"bionemo-recipes/models/esm2/tests/common": [
210214
"bionemo-recipes/models/llama3/tests/common",

0 commit comments

Comments
 (0)