@@ -74,13 +74,10 @@ def train_probe(
         for param in pretrained_model.parameters():
             param.requires_grad = False
         pretrained_model.to(device)
-        OmegaConf.set_struct(cfg, False)
-        cfg.pretrained_model = {"emb_dim": model_cls.emb_dim}
     else:
         pretrained_model, transform, extract_embedding = load_pretrained_model(
             cfg, adaptation_type, device
         )
-
     if adaptation_type == "lora":
         pretrained_model = init_adapters(cfg, pretrained_model, device)
         pretrained_model.train()
@@ -89,6 +86,15 @@ def train_probe(
     else:
         pretrained_model = transform = extract_embedding = None
 
+    # Set cfg.pretrained_model when a custom model class is supplied
+    if model_cls is not None:
+        cfg_pretrained_model = {}
+        cfg_pretrained_model["emb_dim"] = model_cls.emb_dim
+        if hasattr(model_cls, "emb_dim_seg"):
+            cfg_pretrained_model["emb_dim_seg"] = model_cls.emb_dim_seg
+        OmegaConf.set_struct(cfg, False)
+        cfg.pretrained_model = cfg_pretrained_model
+
     # Dict of hyperparameters to search
     hyperparams_dict = get_hyperaparams_dict(cfg)
 
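
For context, the added block depends on OmegaConf's struct flag: a struct-mode config rejects assignments to unknown keys, so `OmegaConf.set_struct(cfg, False)` has to run before the new `pretrained_model` entry can be attached. Below is a minimal sketch of that pattern, assuming a hypothetical `DummyModel` class standing in for `model_cls` (its attribute values are illustrative, not from this commit):

```python
from omegaconf import OmegaConf


class DummyModel:
    """Hypothetical custom model class exposing embedding sizes."""

    emb_dim = 768       # always present
    emb_dim_seg = 384   # optional; only some models define it


model_cls = DummyModel
cfg = OmegaConf.create({"lr": 1e-3})
OmegaConf.set_struct(cfg, True)  # struct mode: assigning unknown keys now raises

if model_cls is not None:
    cfg_pretrained_model = {"emb_dim": model_cls.emb_dim}
    if hasattr(model_cls, "emb_dim_seg"):
        cfg_pretrained_model["emb_dim_seg"] = model_cls.emb_dim_seg
    OmegaConf.set_struct(cfg, False)  # relax struct mode to add the new key
    cfg.pretrained_model = cfg_pretrained_model

print(cfg.pretrained_model.emb_dim)  # -> 768
```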