|
69 | 69 | import torch |
70 | 70 | import torch.distributed as dist |
71 | 71 | from megatron.bridge.data.samplers import build_pretraining_data_loader |
72 | | -from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint |
| 72 | +from megatron.bridge.training.checkpointing import ( |
| 73 | + _generate_model_state_dict, |
| 74 | + _load_model_weights_from_checkpoint, |
| 75 | + apply_peft_adapter_filter_to_state_dict, |
| 76 | +) |
73 | 77 | from megatron.bridge.training.config import DistributedInitConfig, RNGConfig |
74 | 78 | from megatron.bridge.training.mixed_precision import MIXED_PRECISION_RECIPES, get_mixed_precision_config |
75 | 79 | from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer |
|
86 | 90 | get_world_size_safe, |
87 | 91 | ) |
88 | 92 | from megatron.bridge.utils.instantiate_utils import instantiate |
89 | | -from megatron.core import parallel_state, tensor_parallel |
| 93 | +from megatron.core import dist_checkpointing, parallel_state, tensor_parallel |
90 | 94 | from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator |
91 | 95 | from megatron.core.tensor_parallel.mappings import _gather_along_last_dim |
92 | 96 | from megatron.core.transformer.module import Float16Module |
@@ -1117,12 +1121,36 @@ def predict( |
1117 | 1121 | else: |
1118 | 1122 | logger.warning("Could not determine number of layers from model structure") |
1119 | 1123 |
|
1120 | | - logger.info(f"Loading weights from: {resolved_ckpt_dir}") |
1121 | | - _load_model_weights_from_checkpoint( |
1122 | | - checkpoint_path=str(resolved_ckpt_dir), |
1123 | | - model=model, |
1124 | | - dist_ckpt_strictness="ignore_all", |
1125 | | - ) |
| 1124 | + peft_section = run_config.get("peft") |
| 1125 | + if peft_section is not None: |
| 1126 | + pretrained_ckpt = resolve_checkpoint_path(Path(run_config["checkpoint"]["pretrained_checkpoint"])) |
| 1127 | + logger.info(f"Loading base model weights from: {pretrained_ckpt}") |
| 1128 | + _load_model_weights_from_checkpoint( |
| 1129 | + checkpoint_path=str(pretrained_ckpt), |
| 1130 | + model=model, |
| 1131 | + dist_ckpt_strictness="ignore_all", |
| 1132 | + ) |
| 1133 | + |
| 1134 | + unwrapped = [m.module for m in model] |
| 1135 | + peft_cfg = instantiate(peft_section) |
| 1136 | + peft_cfg(unwrapped, training=False) |
| 1137 | + |
| 1138 | + logger.info(f"Loading adapter weights from: {resolved_ckpt_dir}") |
| 1139 | + sharded_sd = _generate_model_state_dict(unwrapped, {}) |
| 1140 | + sharded_sd = apply_peft_adapter_filter_to_state_dict(sharded_sd, peft_cfg) |
| 1141 | + loaded = dist_checkpointing.load(sharded_sd, str(resolved_ckpt_dir), strict="ignore_all") |
| 1142 | + if len(unwrapped) == 1: |
| 1143 | + unwrapped[0].load_state_dict(loaded["model"], strict=False) |
| 1144 | + else: |
| 1145 | + for i, inner in enumerate(unwrapped): |
| 1146 | + inner.load_state_dict(loaded[f"model{i}"], strict=False) |
| 1147 | + else: |
| 1148 | + logger.info(f"Loading weights from: {resolved_ckpt_dir}") |
| 1149 | + _load_model_weights_from_checkpoint( |
| 1150 | + checkpoint_path=str(resolved_ckpt_dir), |
| 1151 | + model=model, |
| 1152 | + dist_ckpt_strictness="ignore_all", |
| 1153 | + ) |
1126 | 1154 | logger.info("Weights loaded successfully") |
1127 | 1155 |
|
1128 | 1156 | # ------------------------------------------------------------------------- |
|
0 commit comments