Skip to content

Commit f991b6b

Browse files
fix: convert to LabelAttentionConfig object when dict
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent ec6742c commit f991b6b

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

torchTextClassifiers/model/components/text_embedder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,15 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
3535
if isinstance(self.attention_config, dict):
3636
self.attention_config = AttentionConfig(**self.attention_config)
3737

38-
self.enable_label_attention = text_embedder_config.label_attention_config is not None
38+
# Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig
39+
self.label_attention_config = text_embedder_config.label_attention_config
40+
if isinstance(self.label_attention_config, dict):
41+
self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
42+
# Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier)
43+
# always see a LabelAttentionConfig instance rather than a raw dict.
44+
self.config.label_attention_config = self.label_attention_config
45+
46+
self.enable_label_attention = self.label_attention_config is not None
3947
if self.enable_label_attention:
4048
self.label_attention_module = LabelAttentionClassifier(self.config)
4149

0 commit comments

Comments
 (0)