Patch summary (free text ahead of the first "diff --git" header; patch tools
skip leading non-diff text).

* openfold3/core/model/primitives/initialization.py
  - trunc_normal_init_ now fills `weights` in place with
    torch.nn.init.trunc_normal_ (bounds at -2*std / +2*std) instead of drawing
    samples with scipy's truncnorm.rvs, reshaping them with numpy and copying
    them into the tensor.
  - truncnorm.std(...) is looked up through the lru_cache'd helper
    _cached_truncnorm_std.
  - The _prod helper and the numpy import are removed.
* openfold3/tests/test_model_checkpoint.py
  - Whitespace-only cleanup inside the YAML fixture of test_make_model_ckpt.
  - New test that times one call of
    InferenceExperimentRunner._warn_on_missing_version_tensor_in_load_statedict
    via the `benchmark` fixture and asserts the mean is below 10 seconds.

NOTE(review): this patch was recovered from a copy whose line breaks and
whitespace runs had been collapsed. Line structure and hunk line counts are
restored and match the @@ headers; Python indentation is inferred from syntax.
The indentation of the YAML fixture and the trailing whitespace on its removed
lines are assumptions -- regenerate the patch from the commit
(index 8df7f2d4..3450580a) before applying it.

diff --git a/openfold3/core/model/primitives/initialization.py b/openfold3/core/model/primitives/initialization.py
index 3c28d788..51e143fa 100644
--- a/openfold3/core/model/primitives/initialization.py
+++ b/openfold3/core/model/primitives/initialization.py
@@ -16,20 +16,13 @@
 """Initialization functions for network parameters."""
 
 import math
+from functools import lru_cache
 
-import numpy as np
 import torch
 from scipy.stats import truncnorm
 from torch import nn
 
 
-def _prod(nums):
-    out = 1
-    for n in nums:
-        out = out * n
-    return out
-
-
 def _calculate_fan(linear_weight_shape, fan="fan_in"):
     fan_out, fan_in = linear_weight_shape
 
@@ -45,18 +38,26 @@ def _calculate_fan(linear_weight_shape, fan="fan_in"):
     return f
 
 
+@lru_cache
+def _cached_truncnorm_std(a, b, loc, scale):
+    return truncnorm.std(a, b, loc, scale)
+
+
 def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
-    shape = weights.shape
-    f = _calculate_fan(shape, fan)
+    f = _calculate_fan(weights.shape, fan)
+
     scale = scale / max(1, f)
-    a = -2
-    b = 2
-    std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
-    size = _prod(shape)
-    samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
-    samples = np.reshape(samples, shape)
+    # truncnorm.std is always 0.8796256610342398
+    std = float(math.sqrt(scale) / _cached_truncnorm_std(a=-2, b=2, loc=0, scale=1))
+
     with torch.no_grad():
-        weights.copy_(torch.tensor(samples, device=weights.device))
+        nn.init.trunc_normal_(
+            weights,
+            mean=0.0,
+            std=std,
+            a=-2.0 * std,
+            b=2.0 * std,
+        )
 
 
 def lecun_normal_init_(weights):
diff --git a/openfold3/tests/test_model_checkpoint.py b/openfold3/tests/test_model_checkpoint.py
index 8df7f2d4..3450580a 100644
--- a/openfold3/tests/test_model_checkpoint.py
+++ b/openfold3/tests/test_model_checkpoint.py
@@ -20,7 +20,10 @@
 import torch
 
 from openfold3.core.config import config_utils
-from openfold3.core.utils.checkpoint_loading_utils import load_checkpoint
+from openfold3.core.utils.checkpoint_loading_utils import (
+    get_state_dict_from_checkpoint,
+    load_checkpoint,
+)
 from openfold3.entry_points.experiment_runner import (
     InferenceExperimentRunner,
     TrainingExperimentRunner,
@@ -73,24 +76,24 @@ def test_make_model_ckpt(
                 loss_module:
                     diffusion:
                         chunk_size: 16
-        
+
         dataset_configs:
             train:
                 weighted-pdb:
-                    dataset_class: WeightedPDBDataset 
-                    weight: 1 
-        
+                    dataset_class: WeightedPDBDataset
+                    weight: 1
+
         dataset_paths:
             weighted-pdb:
                 alignments_directory: null
                 alignment_db_directory: null
-                alignment_array_directory: {tmp_path} 
-                target_structures_directory: {tmp_path} 
+                alignment_array_directory: {tmp_path}
+                target_structures_directory: {tmp_path}
                 target_structure_file_format: npz
-                dataset_cache_file: {test_dummy_file} 
-                reference_molecule_directory: {tmp_path} 
-                template_cache_directory: {tmp_path} 
-                template_structure_array_directory: {tmp_path} 
+                dataset_cache_file: {test_dummy_file}
+                reference_molecule_directory: {tmp_path}
+                template_cache_directory: {tmp_path}
+                template_structure_array_directory: {tmp_path}
                 template_structures_directory: null
                 template_file_format: pkl
                 ccd_file: null
@@ -169,3 +172,32 @@ def test_load_model_ckpt_with_missing_fields_fails(
         inference_runner = InferenceExperimentRunner(inference_config)
         with pytest.raises(RuntimeError):
             inference_runner.setup()
+
+    def test_inference_load_state_dict_benchmark_under_ten_seconds(
+        self, benchmark, default_ckpt_path
+    ):
+        """Guard against regressions in load_state_dict latency during setup."""
+        ckpt = load_checkpoint(default_ckpt_path)
+        state_dict, _ = get_state_dict_from_checkpoint(ckpt, init_from_ema_weights=True)
+        inference_config = InferenceExperimentConfig.model_validate(
+            {"inference_ckpt_path": default_ckpt_path}
+        )
+        inference_runner = InferenceExperimentRunner(inference_config)
+
+        def _load_state_dict():
+            inference_runner._warn_on_missing_version_tensor_in_load_statedict(
+                state_dict
+            )
+
+        benchmark.pedantic(
+            _load_state_dict,
+            rounds=1,
+            iterations=1,
+            warmup_rounds=0,
+        )
+        setup_seconds = benchmark.stats.stats.mean
+
+        assert setup_seconds < 10.0, (
+            f"InferenceExperimentRunner.lightning_module.load_state_dict path took {setup_seconds:.2f}s; "
+            "expected < 10.0s"
+        )