|
1 | 1 | from abc import ABC, abstractmethod |
2 | 2 | from omegaconf import DictConfig, OmegaConf |
| 3 | +from pathlib import Path |
3 | 4 | import timm |
4 | 5 | from timm.data import resolve_data_config |
5 | 6 | from timm.data.transforms_factory import create_transform |
@@ -256,6 +257,69 @@ def extract_embedding(src, pretrained_model, task_type="linear_probing"): |
256 | 257 | return model, transform, extract_embedding |
257 | 258 |
|
258 | 259 |
|
def get_model_from_name(model_name: str, device: str):
    """Load a pretrained model, its transform and an embedding extractor by name.

    The model configuration is read from
    `config/pretrained_model/{model_name}.yaml`, located two directory levels
    above this file, and then handed to `get_model`.

    The list of all available models:
    * uni
    * uni2h
    * virchow
    * virchow2
    * hoptimus0
    * hoptimus1
    * conch
    * titan
    * phikon
    * phikon2
    * hiboub
    * hiboul
    * midnight
    * keep
    * quiltb32
    * plip
    * musk
    * dinov2base
    * dinov2large
    * vitbasepatch16224in21k
    * vitlargepatch16224in21k
    * clipvitbasepatch32
    * clipvitlargepatch14

    Args:
        model_name (str): The name of the model to use.
        device (str): Device to use (cpu, cuda).

    Returns:
        model (torch.nn.Module): Pytorch model instance.
        transform (torchvision.transforms.transforms.Compose): Transform to apply to input image.
        get_embeddings (Callable): Function to extract embeddings.

    Raises:
        FileNotFoundError: If no config file exists for `model_name`. The
            message lists the model names found in the config directory.

    Tip: output function `get_embeddings` signature.
        * src (torch.Tensor): Batch of transformed images with shape (B, 3, H, W).
        * pretrained_model (torch.nn.Module): Model to extract embeddings with.
        * pooled_emb (bool): Whether to output pooled (True) or spatial (False) embeddings.
    """
    # Loading model config
    config_dir = Path(__file__).parent.parent / "config" / "pretrained_model"
    yaml_file = config_dir / f"{model_name}.yaml"
    try:
        # Passed as a string, exactly as the loader received it before.
        model_cfg = OmegaConf.load(str(yaml_file))
    except FileNotFoundError as err:
        # Same exception type as before, but tell the caller which names
        # would have worked instead of only echoing the missing path.
        available = ", ".join(sorted(p.stem for p in config_dir.glob("*.yaml")))
        raise FileNotFoundError(
            f"No config found for model '{model_name}' at {yaml_file}. "
            f"Available models: {available or 'none'}."
        ) from err

    # Getting model, transform, embedding extraction function
    model, transform, extract_embedding = get_model(model_cfg, device)

    # Defining wrapper function to get embeddings
    def get_embeddings(src, pretrained_model, pooled_emb=True):
        # Translate the boolean flag into the task_type string that
        # extract_embedding expects.
        return extract_embedding(
            src,
            pretrained_model,
            task_type="linear_probing" if pooled_emb else "segmentation",
        )

    return model, transform, get_embeddings
| 321 | + |
| 322 | + |
259 | 323 | def get_from_timm(hf_tag: str, timm_kwargs: dict, ckpt_path: str, device: str): |
260 | 324 | """ |
261 | 325 | Adapted from: |
|
0 commit comments