|
42 | 42 | MegatronTrainingJob, |
43 | 43 | ) |
44 | 44 | from art.megatron.lora import apply_lora_adapters |
| 45 | +from art.megatron.merge import merge_lora_adapter |
45 | 46 | from art.megatron.offload import ( |
46 | 47 | OffloadState, |
47 | 48 | clear_optimizer_state, |
@@ -402,7 +403,7 @@ def run_megatron_rl_job( |
402 | 403 | learning_rate=job.config.learning_rate, |
403 | 404 | inputs=micro_inputs, |
404 | 405 | config=job.config, |
405 | | - experimental_config=job.experimental_config, |
| 406 | + experimental_config=cast(dev.TrainConfig, job.experimental_config), |
406 | 407 | ref_logprobs=None, |
407 | 408 | step_index=step_index, |
408 | 409 | sample_index=micro_indices, |
@@ -587,97 +588,6 @@ def _job_cleanup_path(job: MegatronJob) -> str: |
587 | 588 | return job.disk_packed_tensors["dir"] |
588 | 589 |
|
589 | 590 |
|
590 | | -def merge_lora_adapter(lora_path: str) -> None: |
591 | | - base_dir = Path(lora_path) |
592 | | - shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors")) |
593 | | - if not shard_filenames: |
594 | | - return |
595 | | - |
596 | | - shard_files_by_suffix = { |
597 | | - path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path |
598 | | - for path in shard_filenames |
599 | | - } |
600 | | - manifest_filenames = sorted(base_dir.glob("adapter_manifest-*-of-*.json")) |
601 | | - manifest_files_by_suffix = { |
602 | | - path.name.removeprefix("adapter_manifest-").removesuffix(".json"): path |
603 | | - for path in manifest_filenames |
604 | | - } |
605 | | - |
606 | | - if set(shard_files_by_suffix) != set(manifest_files_by_suffix): |
607 | | - raise RuntimeError( |
608 | | - "Shard/manifest coverage mismatch: " |
609 | | - f"shards={sorted(shard_files_by_suffix)}, " |
610 | | - f"manifests={sorted(manifest_files_by_suffix)}" |
611 | | - ) |
612 | | - |
613 | | - entries_by_key: dict[str, list[tuple[dict[str, Any], torch.Tensor]]] = {} |
614 | | - for suffix in sorted(shard_files_by_suffix): |
615 | | - shard_path = shard_files_by_suffix[suffix] |
616 | | - manifest_path = manifest_files_by_suffix[suffix] |
617 | | - with open(manifest_path, "r", encoding="utf-8") as manifest_file: |
618 | | - shard_manifest: dict[str, dict[str, Any]] = json.load(manifest_file) |
619 | | - with safe_open(shard_path, framework="pt") as file: |
620 | | - shard_tensors = {key: file.get_tensor(key) for key in file.keys()} |
621 | | - |
622 | | - if set(shard_tensors) != set(shard_manifest): |
623 | | - raise RuntimeError( |
624 | | - f"Tensor/manifest key mismatch for shard suffix={suffix}: " |
625 | | - f"tensor_keys={sorted(shard_tensors)}, " |
626 | | - f"manifest_keys={sorted(shard_manifest)}" |
627 | | - ) |
628 | | - for key, tensor in shard_tensors.items(): |
629 | | - entries_by_key.setdefault(key, []).append((shard_manifest[key], tensor)) |
630 | | - |
631 | | - adapter_model: dict[str, torch.Tensor] = {} |
632 | | - for key, key_entries in entries_by_key.items(): |
633 | | - first_manifest = key_entries[0][0] |
634 | | - sharded = bool(first_manifest["sharded"]) |
635 | | - shard_world_size = int(first_manifest["shard_world_size"]) |
636 | | - for manifest_entry, _tensor in key_entries: |
637 | | - if bool(manifest_entry["sharded"]) != sharded: |
638 | | - raise RuntimeError(f"Inconsistent sharded flag for key={key}") |
639 | | - if int(manifest_entry["shard_world_size"]) != shard_world_size: |
640 | | - raise RuntimeError(f"Inconsistent shard world size for key={key}") |
641 | | - |
642 | | - if not sharded: |
643 | | - if len(key_entries) != 1: |
644 | | - raise RuntimeError( |
645 | | - f"Replicated key={key} expected 1 shard, got {len(key_entries)}" |
646 | | - ) |
647 | | - tensor = key_entries[0][1] |
648 | | - else: |
649 | | - shard_rank_to_tensor: dict[int, torch.Tensor] = {} |
650 | | - for manifest_entry, shard_tensor in key_entries: |
651 | | - shard_rank = int(manifest_entry["shard_rank"]) |
652 | | - if shard_rank in shard_rank_to_tensor: |
653 | | - raise RuntimeError( |
654 | | - f"Duplicate shard_rank={shard_rank} for key={key}" |
655 | | - ) |
656 | | - shard_rank_to_tensor[shard_rank] = shard_tensor |
657 | | - |
658 | | - expected_shard_ranks = set(range(shard_world_size)) |
659 | | - if set(shard_rank_to_tensor) != expected_shard_ranks: |
660 | | - raise RuntimeError( |
661 | | - f"Shard rank coverage mismatch for key={key}: " |
662 | | - f"expected {sorted(expected_shard_ranks)}, got {sorted(shard_rank_to_tensor)}" |
663 | | - ) |
664 | | - |
665 | | - ordered_shards = [ |
666 | | - shard_rank_to_tensor[shard_rank] |
667 | | - for shard_rank in range(shard_world_size) |
668 | | - ] |
669 | | - concat_dim = 1 if "lora_A" in key else 0 |
670 | | - tensor = torch.cat(ordered_shards, dim=concat_dim) |
671 | | - adapter_model[key] = tensor |
672 | | - |
673 | | - adapter_model_path = base_dir / "adapter_model.safetensors" |
674 | | - save_file(adapter_model, adapter_model_path) |
675 | | - for filename in shard_filenames: |
676 | | - filename.unlink() |
677 | | - for filename in manifest_filenames: |
678 | | - filename.unlink() |
679 | | - |
680 | | - |
681 | 591 | def _load_sft_batch_from_disk( |
682 | 592 | batch_dir: str, |
683 | 593 | ) -> tuple[dict[str, Any], list[dict[str, torch.Tensor]]]: |
|
0 commit comments