@@ -126,6 +126,8 @@ def get_model(model_cfg: dict, device: str):
126126 model , transform = get_titan (model_cfg .ckpt_path )
127127 elif model_cfg .model_name == "midnight" :
128128 model , transform = get_midnight (model_cfg .ckpt_path )
129+ elif model_cfg .model_name == "genbio-pathfm" :
130+ model , transform = get_genbio_pathfm (model_cfg .ckpt_path )
129131 else :
130132 model , transform , tokenizer = get_from_safetensors (
131133 model_cfg .ckpt_path , use_fast = "dinov3" in model_cfg .model_name
@@ -366,6 +368,8 @@ def extract_embedding(src, pretrained_model, task_type="linear_probing"):
366368 emb = pretrained_model .trunk (src , return_all_tokens = True )[:, 1 :]
367369 elif model_cfg .model_name == "openmidnight" :
368370 emb = pretrained_model .get_intermediate_layers (src )[0 ]
371+ elif model_cfg .model_name == "genbio-pathfm" :
372+ emb = pretrained_model .forward_with_patches (src )[1 ]
369373 else :
370374 emb = pretrained_model .forward_features (src )[:, 1 :]
371375
@@ -394,6 +398,7 @@ def get_model_from_name(model_name: str, device: str):
394398 * hoptimus0
395399 * h0mini
396400 * hoptimus1
401+ * genbio-pathfm
397402 * provgigapath
398403 * conch
399404 * titan
@@ -756,3 +761,36 @@ def get_titan(ckpt_path: str):
756761 model , transform = titan .return_conch ()
757762
758763 return model , transform
764+
765+
766+ def get_genbio_pathfm (ckpt_path : str ):
767+ """
768+ Adapted from:
769+ - https://github.com/genbio-ai/genbio-pathfm
770+
771+ :param ckpt_path: path to the stored checkpoint.
772+ """
773+ try :
774+ from genbio_pathfm .model import GenBio_PathFM_Inference
775+ except ImportError :
776+ raise ImportError (
777+ "In order to use GenBio-PathFM, please run the following: 'pip install git+https://github.com/genbio-ai/genbio-pathfm.git'"
778+ )
779+
780+ from torchvision import transforms
781+
782+ # Model
783+ model = GenBio_PathFM_Inference (ckpt_path , device = "cpu" )
784+
785+ # Transform
786+ transform = transforms .Compose ([
787+ transforms .Resize ((224 ,224 )),
788+ transforms .ToTensor (),
789+ transforms .Normalize (
790+ mean = (0.697 , 0.575 , 0.728 ),
791+ std = (0.188 , 0.240 , 0.187 )
792+ ),
793+ ])
794+
795+ return model , transform
796+
0 commit comments