Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions openfold3/core/model/primitives/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,13 @@
"""Initialization functions for network parameters."""

import math
from functools import lru_cache

import numpy as np
import torch
from scipy.stats import truncnorm
from torch import nn


def _prod(nums):
out = 1
for n in nums:
out = out * n
return out


def _calculate_fan(linear_weight_shape, fan="fan_in"):
fan_out, fan_in = linear_weight_shape

Expand All @@ -45,18 +38,26 @@ def _calculate_fan(linear_weight_shape, fan="fan_in"):
return f


@lru_cache
def _cached_truncnorm_std(a, b, loc, scale):
return truncnorm.std(a, b, loc, scale)


def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
    """Fill ``weights`` in place with variance-scaled truncated-normal values.

    The target variance is ``scale / max(1, f)``, where ``f`` is the fan
    value that ``_calculate_fan`` derives from ``weights.shape``. Samples
    are drawn with ``torch.nn.init.trunc_normal_`` and truncated at two
    standard deviations on either side of zero.

    Args:
        weights: Tensor to initialise in place. Its shape is handed to
            ``_calculate_fan``, which unpacks it as ``(fan_out, fan_in)``.
        scale: Variance numerator before division by the fan value.
        fan: Fan mode forwarded to ``_calculate_fan``.
    """
    f = _calculate_fan(weights.shape, fan)
    scale = scale / max(1, f)

    # Truncating a normal at +/-2 sigma shrinks its standard deviation, so
    # the sampler's std is divided by the std of a unit normal truncated to
    # [-2, 2] to compensate.
    # truncnorm.std is always 0.8796256610342398
    unit_truncated_std = _cached_truncnorm_std(a=-2, b=2, loc=0, scale=1)
    std = float(math.sqrt(scale) / unit_truncated_std)

    # Sample directly into ``weights``. The earlier SciPy draw and copy is
    # dropped because this call overwrote its result anyway.
    with torch.no_grad():
        nn.init.trunc_normal_(
            weights,
            mean=0.0,
            std=std,
            a=-2.0 * std,
            b=2.0 * std,
        )


def lecun_normal_init_(weights):
Expand Down
54 changes: 43 additions & 11 deletions openfold3/tests/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import torch

from openfold3.core.config import config_utils
from openfold3.core.utils.checkpoint_loading_utils import load_checkpoint
from openfold3.core.utils.checkpoint_loading_utils import (
get_state_dict_from_checkpoint,
load_checkpoint,
)
from openfold3.entry_points.experiment_runner import (
InferenceExperimentRunner,
TrainingExperimentRunner,
Expand Down Expand Up @@ -73,24 +76,24 @@ def test_make_model_ckpt(
loss_module:
diffusion:
chunk_size: 16

dataset_configs:
train:
weighted-pdb:
dataset_class: WeightedPDBDataset
weight: 1
dataset_class: WeightedPDBDataset
weight: 1

dataset_paths:
weighted-pdb:
alignments_directory: null
alignment_db_directory: null
alignment_array_directory: {tmp_path}
target_structures_directory: {tmp_path}
alignment_array_directory: {tmp_path}
target_structures_directory: {tmp_path}
target_structure_file_format: npz
dataset_cache_file: {test_dummy_file}
reference_molecule_directory: {tmp_path}
template_cache_directory: {tmp_path}
template_structure_array_directory: {tmp_path}
dataset_cache_file: {test_dummy_file}
reference_molecule_directory: {tmp_path}
template_cache_directory: {tmp_path}
template_structure_array_directory: {tmp_path}
template_structures_directory: null
template_file_format: pkl
ccd_file: null
Expand Down Expand Up @@ -169,3 +172,32 @@ def test_load_model_ckpt_with_missing_fields_fails(
inference_runner = InferenceExperimentRunner(inference_config)
with pytest.raises(RuntimeError):
inference_runner.setup()

def test_inference_load_state_dict_benchmark_under_ten_seconds(
    self, benchmark, default_ckpt_path
):
    """Fail if the runner's state-dict loading path takes 10 seconds or more.

    Loads the default checkpoint, extracts its state dict with
    ``init_from_ema_weights=True``, and times a single call to
    ``_warn_on_missing_version_tensor_in_load_statedict`` on a fresh
    ``InferenceExperimentRunner``.
    """
    checkpoint = load_checkpoint(default_ckpt_path)
    state_dict, _ = get_state_dict_from_checkpoint(
        checkpoint, init_from_ema_weights=True
    )
    runner = InferenceExperimentRunner(
        InferenceExperimentConfig.model_validate(
            {"inference_ckpt_path": default_ckpt_path}
        )
    )

    def _timed_call():
        # NOTE(review): presumably this helper wraps load_state_dict on the
        # lightning module -- confirm against the runner implementation.
        runner._warn_on_missing_version_tensor_in_load_statedict(state_dict)

    # One round, one iteration, no warmup: a single cold call is measured.
    benchmark.pedantic(_timed_call, rounds=1, iterations=1, warmup_rounds=0)
    setup_seconds = benchmark.stats.stats.mean

    assert setup_seconds < 10.0, (
        f"InferenceExperimentRunner.lightning_module.load_state_dict path took {setup_seconds:.2f}s; "
        "expected < 10.0s"
    )