File tree Expand file tree Collapse file tree
torchTextClassifiers/model/components Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -35,7 +35,15 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
3535 if isinstance (self .attention_config , dict ):
3636 self .attention_config = AttentionConfig (** self .attention_config )
3737
38- self .enable_label_attention = text_embedder_config .label_attention_config is not None
38+ # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig
39+ self .label_attention_config = text_embedder_config .label_attention_config
40+ if isinstance (self .label_attention_config , dict ):
41+ self .label_attention_config = LabelAttentionConfig (** self .label_attention_config )
42+ # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier)
43+ # always see a LabelAttentionConfig instance rather than a raw dict.
44+ self .config .label_attention_config = self .label_attention_config
45+
46+ self .enable_label_attention = self .label_attention_config is not None
3947 if self .enable_label_attention :
4048 self .label_attention_module = LabelAttentionClassifier (self .config )
4149
You can’t perform that action at this time.
0 commit comments