Skip to content

Commit 8d36627

Browse files
author
Raktim Mitra
committed
foundry install
1 parent ebc072c commit 8d36627

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

src/foundry/inference_engines/checkpoint_registry.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ def get_default_path(self):
7878

7979

8080
REGISTERED_CHECKPOINTS = {
81+
"rfd3na": RegisteredCheckpoint(
82+
url="https://files.ipd.uw.edu/pub/rfdiffusion3na/rfd3na-1190.ckpt",
83+
filename="rfd3na_1190.ckpt",
84+
description="RFdiffusion3NA checkpoint",
85+
),
86+
8187
"rfd3": RegisteredCheckpoint(
8288
url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
8389
filename="rfd3_latest.ckpt",

src/foundry_cli/download_checkpoints.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) ->
107107
"""Install a single model checkpoint.
108108
109109
Args:
110-
model_name: Name of the model (rfd3, rf3, mpnn)
110+
model_name: Name of the model (rfd3, rfd3na, rf3, mpnn)
111111
checkpoint_dir: Directory to save checkpoints
112112
force: Overwrite existing checkpoint if it exists
113113
"""
@@ -145,7 +145,7 @@ def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) ->
145145
def install(
146146
models: list[str] = typer.Argument(
147147
...,
148-
help="Models to install: 'all', 'rfd3', 'rf3', 'mpnn', or a combination thereof",
148+
help="Models to install: 'all', 'rfd3', 'rfd3na', 'rf3', 'mpnn', or a combination thereof",
149149
),
150150
checkpoint_dir: Optional[Path] = typer.Option(
151151
None,
@@ -160,7 +160,7 @@ def install(
160160
"""Install model checkpoints for foundry.
161161
Examples:
162162
foundry install all
163-
foundry install rfd3 rf3
163+
foundry install rfd3 rfd3na rf3
164164
foundry install proteinmpnn --checkpoint-dir ./checkpoints
165165
"""
166166
# Determine checkpoint directory
@@ -173,7 +173,7 @@ def install(
173173
if "all" in models:
174174
models_to_install = list(REGISTERED_CHECKPOINTS.keys())
175175
elif "base-models" in models:
176-
models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"]
176+
models_to_install = ["rfd3", "rfd3na", "proteinmpnn", "ligandmpnn", "rf3"]
177177
else:
178178
models_to_install = models
179179

0 commit comments

Comments
 (0)