@@ -50,11 +50,11 @@ class ModelConfig:
5050 """Base configuration class for text classifiers."""
5151
5252 embedding_dim : int
53+ num_classes : int
5354 categorical_vocabulary_sizes : Optional [List [int ]] = None
5455 categorical_embedding_dims : Optional [Union [List [int ], int ]] = None
55- num_classes : Optional [int ] = None
5656 attention_config : Optional [AttentionConfig ] = None
57- label_attention_config : Optional [LabelAttentionConfig ] = None
57+ n_heads_label_attention : Optional [int ] = None
5858
5959 def to_dict (self ) -> Dict [str , Any ]:
6060 return asdict (self )
@@ -142,7 +142,7 @@ def __init__(
142142 self .embedding_dim = model_config .embedding_dim
143143 self .categorical_vocabulary_sizes = model_config .categorical_vocabulary_sizes
144144 self .num_classes = model_config .num_classes
145- self .enable_label_attention = model_config .label_attention_config is not None
145+ self .enable_label_attention = model_config .n_heads_label_attention is not None
146146
147147 if self .tokenizer .output_vectorized :
148148 self .text_embedder = None
@@ -156,7 +156,10 @@ def __init__(
156156 embedding_dim = self .embedding_dim ,
157157 padding_idx = tokenizer .padding_idx ,
158158 attention_config = model_config .attention_config ,
159- label_attention_config = model_config .label_attention_config ,
159+ label_attention_config = LabelAttentionConfig (
160+ n_head = model_config .n_heads_label_attention ,
161+ num_classes = model_config .num_classes ,
162+ ),
160163 )
161164 self .text_embedder = TextEmbedder (
162165 text_embedder_config = text_embedder_config ,
@@ -697,10 +700,6 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif
697700
698701 # Reconstruct model_config
699702 model_config = ModelConfig .from_dict (metadata ["model_config" ])
700- if isinstance (model_config .label_attention_config , dict ):
701- model_config .label_attention_config = LabelAttentionConfig (
702- ** model_config .label_attention_config
703- )
704703
705704 # Create instance
706705 instance = cls (
0 commit comments