Skip to content

Commit cbb5c41

Browse files
committed
Added cfg setting for probe training with custom model
1 parent e7fffbe commit cbb5c41

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

src/thunder/tasks/train_eval_probe.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,10 @@ def train_probe(
7474
for param in pretrained_model.parameters():
7575
param.requires_grad = False
7676
pretrained_model.to(device)
77-
OmegaConf.set_struct(cfg, False)
78-
cfg.pretrained_model = {"emb_dim": model_cls.emb_dim}
7977
else:
8078
pretrained_model, transform, extract_embedding = load_pretrained_model(
8179
cfg, adaptation_type, device
8280
)
83-
8481
if adaptation_type == "lora":
8582
pretrained_model = init_adapters(cfg, pretrained_model, device)
8683
pretrained_model.train()
@@ -89,6 +86,15 @@ def train_probe(
8986
else:
9087
pretrained_model = transform = extract_embedding = None
9188

89+
# Setting cfg.pretrained_model if custom model
90+
if model_cls is not None:
91+
cfg_pretrained_model = {}
92+
cfg_pretrained_model["emb_dim"] = model_cls.emb_dim
93+
if hasattr(model_cls, "emb_dim_seg"):
94+
cfg_pretrained_model["emb_dim_seg"] = model_cls.emb_dim_seg
95+
OmegaConf.set_struct(cfg, False)
96+
cfg.pretrained_model = cfg_pretrained_model
97+
9298
# Dict of hyperparameters to search
9399
hyperparams_dict = get_hyperaparams_dict(cfg)
94100

0 commit comments

Comments (0)