Skip to content

Commit 1aa03a3

Browse files
committed
API 'get_model_from_name' function to get a model from string name
1 parent 1d39571 commit 1aa03a3

3 files changed

Lines changed: 67 additions & 1 deletion

File tree

docs/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@
77
::: thunder.generate_splits
88

99
::: thunder.models.PretrainedModel
10+
11+
::: thunder.models.get_model_from_name

src/thunder/models/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from .download import download_models
2-
from .pretrained_models import PretrainedModel
2+
from .pretrained_models import get_model_from_name, PretrainedModel

src/thunder/models/pretrained_models.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from abc import ABC, abstractmethod
22
from omegaconf import DictConfig, OmegaConf
3+
from pathlib import Path
34
import timm
45
from timm.data import resolve_data_config
56
from timm.data.transforms_factory import create_transform
@@ -256,6 +257,69 @@ def extract_embedding(src, pretrained_model, task_type="linear_probing"):
256257
return model, transform, extract_embedding
257258

258259

260+
def get_model_from_name(model_name: str, device: str):
261+
"""Loading pretrained model from input name.
262+
263+
The list of all available models:
264+
* uni
265+
* uni2h
266+
* virchow
267+
* virchow2
268+
* hoptimus0
269+
* hoptimus1
270+
* conch
271+
* titan
272+
* phikon
273+
* phikon2
274+
* hiboub
275+
* hiboul
276+
* midnight
277+
* keep
278+
* quiltb32
279+
* plip
280+
* musk
281+
* dinov2base
282+
* dinov2large
283+
* vitbasepatch16224in21k
284+
* vitlargepatch16224in21k
285+
* clipvitbasepatch32
286+
* clipvitlargepatch14
287+
288+
Args:
289+
model_name (str): The name of the model to use.
290+
device (str): Device to use (cpu, cuda).
291+
292+
Returns:
293+
model (torch.nn.Module): Pytorch model instance.
294+
transform (torchvision.transforms.transforms.Compose): Transform to apply to input image.
295+
get_embeddings (Callable): Function to extract embeddings.
296+
297+
Tip: output function `get_embeddings` signature.
298+
* src (torch.Tensor): Batch of transformed images with shape (B, 3, H, W).
299+
* pretrained_model (torch.nn.Module): Model to extract embeddings with.
300+
* pooled_emb (bool): Whether to output pooled (True) or spatial (False) embeddings.
301+
"""
302+
303+
# Loading model config
304+
yaml_file = (
305+
f"{Path(__file__).parent.parent}/config/pretrained_model/{model_name}.yaml"
306+
)
307+
model_cfg = OmegaConf.load(yaml_file)
308+
309+
# Getting model, transform, embedding extraction function
310+
model, transform, extract_embedding = get_model(model_cfg, device)
311+
312+
# Defining wrapper function to get embeddings
313+
def get_embeddings(src, pretrained_model, pooled_emb=True):
314+
return extract_embedding(
315+
src,
316+
pretrained_model,
317+
task_type="linear_probing" if pooled_emb else "segmentation",
318+
)
319+
320+
return model, transform, get_embeddings
321+
322+
259323
def get_from_timm(hf_tag: str, timm_kwargs: dict, ckpt_path: str, device: str):
260324
"""
261325
Adapted from:

0 commit comments

Comments
 (0)