```diff
 log = logging.getLogger(__name__)
 
 
+def _warn_configuration_mismatch_during_finetune(
+    input_descriptor: dict,
+    pretrained_descriptor: dict,
+    model_branch: str = "Default",
+) -> None:
+    """
+    Warn about configuration mismatches between the input descriptor and the
+    pretrained model when fine-tuning without the --use-pretrain-script option.
+
+    This function warns when the configurations differ, since state_dict
+    initialization will then only pick the relevant keys from the pretrained
+    model (e.g., the first 6 layers of a 16-layer model).
+
+    Parameters
+    ----------
+    input_descriptor : dict
+        Descriptor configuration from input.json.
+    pretrained_descriptor : dict
+        Descriptor configuration from the pretrained model.
+    model_branch : str
+        Model branch name for logging context.
+    """
+    if input_descriptor == pretrained_descriptor:
+        return
+
+    # Collect differences
+    differences = []
+
+    # Keys whose values differ, plus keys present only in input.json
+    for key in input_descriptor:
+        if key in pretrained_descriptor:
+            if input_descriptor[key] != pretrained_descriptor[key]:
+                differences.append(
+                    f"  {key}: {input_descriptor[key]} (input) vs {pretrained_descriptor[key]} (pretrained)"
+                )
+        else:
+            differences.append(f"  {key}: {input_descriptor[key]} (input only)")
+
+    # Keys present only in the pretrained model
+    for key in pretrained_descriptor:
+        if key not in input_descriptor:
+            differences.append(
+                f"  {key}: {pretrained_descriptor[key]} (pretrained only)"
+            )
+
+    if differences:
+        log.warning(
+            "Descriptor configuration mismatch detected between input.json and the pretrained model "
+            f"(branch '{model_branch}'). State dict initialization will only use compatible parameters "
+            "from the pretrained model. Mismatched configuration:\n"
+            + "\n".join(differences)
+        )
+
+
 class Trainer:
     def __init__(
         self,
```
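For reference, a minimal sketch of what the new helper emits for a toy mismatch. The descriptor fields below (`num_layers`, `attn_dim`) are made-up illustrations, not real descriptor keys, and `_warn_configuration_mismatch_during_finetune` is assumed to be in scope from the module patched above:

```python
import logging

# Ensure warnings reach the console in a bare script.
logging.basicConfig(level=logging.WARNING)

# Hypothetical configs: input.json requests a smaller model than the
# pretrained checkpoint provides (field names are illustrative only).
input_descriptor = {"type": "se_atten", "num_layers": 6}
pretrained_descriptor = {"type": "se_atten", "num_layers": 16, "attn_dim": 128}

_warn_configuration_mismatch_during_finetune(
    input_descriptor, pretrained_descriptor, model_branch="Default"
)
# Expected warning body (matching keys such as "type" are skipped):
#   num_layers: 6 (input) vs 16 (pretrained)
#   attn_dim: 128 (pretrained only)
```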
```diff
@@ -117,6 +171,8 @@ def __init__(
         training_params = config["training"]
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
+        # Store model params for finetune warning comparisons
+        self.model_params = model_params
         self.finetune_update_stat = False
         self.model_keys = (
             list(model_params["model_dict"]) if self.multi_task else ["Default"]
```
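The stored `model_params` is later read in one of two shapes, which is what `self.multi_task = "model_dict" in model_params` distinguishes. A rough sketch of the assumed layouts (branch names and values are placeholders, not a full schema):

```python
# Single-task: the descriptor sits at the top level of model_params.
single_task_params = {
    "descriptor": {"type": "se_e2_a", "rcut": 6.0},
}

# Multi-task: one descriptor per model branch under "model_dict";
# self.model_keys then holds the branch names.
multi_task_params = {
    "model_dict": {
        "water": {"descriptor": {"type": "se_e2_a", "rcut": 6.0}},
        "metal": {"descriptor": {"type": "se_e2_a", "rcut": 9.0}},
    },
}
```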
```diff
@@ -512,6 +568,37 @@ def collect_single_finetune_params(
             )
 
         # collect model params from the pretrained model
+        # First check for configuration mismatches and warn if needed
+        pretrained_model_params = state_dict["_extra_state"]["model_params"]
+        for model_key in self.model_keys:
+            finetune_rule_single = self.finetune_links[model_key]
+            _model_key_from = finetune_rule_single.get_model_branch()
+
+            # Get current model descriptor config
+            if self.multi_task:
+                current_descriptor = self.model_params["model_dict"][
+                    model_key
+                ].get("descriptor", {})
+            else:
+                current_descriptor = self.model_params.get("descriptor", {})
+
+            # Get pretrained model descriptor config
+            if "model_dict" in pretrained_model_params:
+                pretrained_descriptor = pretrained_model_params[
+                    "model_dict"
+                ][_model_key_from].get("descriptor", {})
+            else:
+                pretrained_descriptor = pretrained_model_params.get(
+                    "descriptor", {}
+                )
+
+            # Warn about configuration mismatches
+            _warn_configuration_mismatch_during_finetune(
+                current_descriptor,
+                pretrained_descriptor,
+                _model_key_from,
+            )
+
         for model_key in self.model_keys:
             finetune_rule_single = self.finetune_links[model_key]
             collect_single_finetune_params(
```
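Both descriptor lookups in this hunk follow the same branch-aware pattern (multi-task layouts index into `"model_dict"`, single-task layouts read the top level). If this duplication grows, a small helper could centralize it; the sketch below is hypothetical and not part of this commit:

```python
def _get_branch_descriptor(model_params: dict, model_key: str) -> dict:
    """Return the descriptor config for one branch, covering both the
    single-task layout and the multi-task "model_dict" layout."""
    if "model_dict" in model_params:
        return model_params["model_dict"][model_key].get("descriptor", {})
    return model_params.get("descriptor", {})


# The loop body above would then reduce to two symmetric calls:
#   current_descriptor = _get_branch_descriptor(self.model_params, model_key)
#   pretrained_descriptor = _get_branch_descriptor(
#       pretrained_model_params, _model_key_from
#   )
```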