Skip to content

Commit eef30fd

Browse files
authored
Remove dep on torchtune for weight conversion (#17515)
### Summary

Remove the dependency on torchtune for weight conversion. After this change, torchtune is only used for model definitions in:

- phi-3-mini-lora
- llama3_2_vision (these can't be removed)

A few other checkpoint conversions still go through FullModelHFCheckpointer; that dependency can be removed in the next PR.

### Test plan

CI
1 parent 5a9b280 commit eef30fd

15 files changed

Lines changed: 55 additions & 23 deletions

File tree

examples/models/checkpoint.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import json
1111
import os
12+
import re
1213
from pathlib import Path
1314
from typing import Any, Dict, Optional
1415

@@ -112,3 +113,30 @@ def load_checkpoint_from_pytorch_model(input_dir: str) -> Dict:
112113
return state_dict
113114

114115
raise FileNotFoundError(f"Could not find pytorch_model checkpoint in {input_dir}")
116+
117+
118+
def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    """Map a state dict key using a mapping dictionary with "{}" layer number placeholders.

    Keys that contain a numeric layer segment (e.g. "layers.3.attn.weight") are
    abstracted to a template form ("layers.{}.attn.weight") for lookup in
    ``mapping_dict``, then the layer number is substituted back into the mapped
    result. Keys without a numeric segment are looked up directly.

    Args:
        key: A state dict key, possibly containing a ".<digits>" layer index.
        mapping_dict: Maps source-format keys (with "{}" in place of layer
            numbers) to destination-format keys (also with "{}" placeholders).

    Returns:
        The mapped key with the original layer number re-inserted.

    Raises:
        Exception: If the key is not present in ``mapping_dict``, or if a key
            appears to contain a layer number that cannot be located (e.g. the
            digit segment is not preceded by a dot).
    """
    try:
        # Checks if there is a layer # in the key
        if any(k.isdigit() for k in key.split(".")):
            # Replace layer number with "{}" to create key for lookup
            abstract_key = re.sub(r"(\.\d+)", ".{}", key)
            match = re.search(r"\.(\d+)", key)
            if match is None:
                # Reachable when a digit-only segment leads the key (no
                # preceding "."), so the ".<digits>" pattern never matches.
                raise Exception(
                    f'Error converting the state dict. Could not find layer number in key: "{key}". '
                    "Please make sure you're loading a checkpoint with the right format. "
                )
            layer_num = match.group(1)
            new_key = mapping_dict[abstract_key]
            # Fill the "{}" placeholder(s) with the first layer number found.
            new_key = new_key.format(layer_num)
        else:
            new_key = mapping_dict[key]
    except KeyError as e:
        raise Exception(
            f'Error converting the state dict. Found unexpected key: "{key}". '
            "Please make sure you're loading a checkpoint with the right format. "
        ) from e

    return new_key

examples/models/codegen/convert_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import torch
66

7-
from torchtune.models.convert_weights import get_mapped_key
7+
from executorch.examples.models.checkpoint import get_mapped_key
88

99
# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
1010
_HF__CODEGEN_2_FROM_META = {

examples/models/gemma/convert_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Dict
66

77
import torch
8-
from safetensors.torch import load_file
98

10-
from torchtune.models.convert_weights import get_mapped_key
9+
from executorch.examples.models.checkpoint import get_mapped_key
10+
from safetensors.torch import load_file
1111

1212

1313
# Weight mappings from Gemma's checkpoint to ExecuTorch's transformer parameters.

examples/models/gemma2/convert_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from typing import Dict
1111

1212
import torch
13-
from safetensors.torch import load_file
1413

15-
from torchtune.models.convert_weights import get_mapped_key
14+
from executorch.examples.models.checkpoint import get_mapped_key
15+
from safetensors.torch import load_file
1616

1717

1818
# Weight mappings from Gemma 2's checkpoint to ExecuTorch's transformer parameters.

examples/models/gemma3/convert_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Dict
66

77
import torch
8-
from safetensors.torch import load_file
98

10-
from torchtune.models.convert_weights import get_mapped_key
9+
from executorch.examples.models.checkpoint import get_mapped_key
10+
from safetensors.torch import load_file
1111

1212

1313
# Weight mappings from Gemma 3's checkpoint to ExecuTorch's transformer parameters.

examples/models/glm/convert_weights.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Dict
44

55
import torch
6+
from executorch.examples.models.checkpoint import get_mapped_key
67
from safetensors.torch import load_file
7-
from torchtune.models.convert_weights import get_mapped_key
88

99
# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
1010
_GLM_FROM_META = {

examples/models/granite/convert_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Dict
66

77
import torch
8-
from safetensors.torch import load_file
98

10-
from torchtune.models.convert_weights import get_mapped_key
9+
from executorch.examples.models.checkpoint import get_mapped_key
10+
from safetensors.torch import load_file
1111

1212

1313
# Weight mappings from Granite 3's checkpoint to ExecuTorch's transformer parameters.

examples/models/internvl3/convert_weights.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
from typing import Dict
33

44
import torch
5+
from executorch.examples.models.checkpoint import get_mapped_key
56

67
from executorch.examples.models.smollm3.convert_weights import load_checkpoint
7-
from torchtune.models.convert_weights import get_mapped_key
88

99
# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
1010
_INTERNVL_TO_META = {

examples/models/lfm2/convert_weights.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from typing import Dict
44

55
import torch
6-
from safetensors.torch import load_file
76

8-
from torchtune.models.convert_weights import get_mapped_key
7+
from executorch.examples.models.checkpoint import get_mapped_key
8+
from safetensors.torch import load_file
99

1010
_LFM_2_TO_META = {
1111
"model.embed_tokens.weight": "tok_embeddings.weight",

examples/models/llama/convert_weights.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from typing import Dict
22

33
import torch
4+
from executorch.examples.models.checkpoint import get_mapped_key
45

56
from safetensors.torch import load_file
6-
from torchtune.models.convert_weights import get_mapped_key
77

88
_UNSLOTH_TO_META = {
99
"base_model.model.model.layers.{}.mlp.down_proj.lora_A.weight": "layers.{}.feed_forward.w2.lora_a.weight",

0 commit comments

Comments
 (0)