
Commit 4577614

chore!: avoid repeating num_classes for LabelAttentionConfig
Stop exposing LabelAttentionConfig in ModelConfig and build it directly inside the wrapper; callers now only provide the number of attention heads.
1 parent 4a56bd5 commit 4577614
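
A minimal before/after sketch of the migration; the field names come from the diff below, while the concrete values (128, 10, 4) are only illustrative:

    # Before: the caller built LabelAttentionConfig itself, repeating num_classes
    config = ModelConfig(
        embedding_dim=128,
        num_classes=10,
        label_attention_config=LabelAttentionConfig(n_head=4, num_classes=10),
    )

    # After: only the head count is passed; the wrapper builds LabelAttentionConfig
    config = ModelConfig(
        embedding_dim=128,
        num_classes=10,
        n_heads_label_attention=4,
    )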

1 file changed: 7 additions & 8 deletions

torchTextClassifiers/torchTextClassifiers.py
@@ -50,11 +50,11 @@ class ModelConfig:
     """Base configuration class for text classifiers."""

     embedding_dim: int
+    num_classes: int
     categorical_vocabulary_sizes: Optional[List[int]] = None
     categorical_embedding_dims: Optional[Union[List[int], int]] = None
-    num_classes: Optional[int] = None
     attention_config: Optional[AttentionConfig] = None
-    label_attention_config: Optional[LabelAttentionConfig] = None
+    n_heads_label_attention: Optional[int] = None

     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
@@ -142,7 +142,7 @@ def __init__(
         self.embedding_dim = model_config.embedding_dim
         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
         self.num_classes = model_config.num_classes
-        self.enable_label_attention = model_config.label_attention_config is not None
+        self.enable_label_attention = model_config.n_heads_label_attention is not None

         if self.tokenizer.output_vectorized:
             self.text_embedder = None
@@ -156,7 +156,10 @@ def __init__(
             embedding_dim=self.embedding_dim,
             padding_idx=tokenizer.padding_idx,
             attention_config=model_config.attention_config,
-            label_attention_config=model_config.label_attention_config,
+            label_attention_config=LabelAttentionConfig(
+                n_head=model_config.n_heads_label_attention,
+                num_classes=model_config.num_classes,
+            ),
         )
         self.text_embedder = TextEmbedder(
             text_embedder_config=text_embedder_config,
@@ -697,10 +700,6 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif

         # Reconstruct model_config
         model_config = ModelConfig.from_dict(metadata["model_config"])
-        if isinstance(model_config.label_attention_config, dict):
-            model_config.label_attention_config = LabelAttentionConfig(
-                **model_config.label_attention_config
-            )

         # Create instance
         instance = cls(
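
Because n_heads_label_attention is a plain Optional[int], ModelConfig now round-trips through to_dict/from_dict without the dict-to-dataclass fix-up that load previously needed. A sketch of that round trip, assuming from_dict simply reverses asdict:

    from dataclasses import asdict

    config = ModelConfig(embedding_dim=128, num_classes=10, n_heads_label_attention=4)
    restored = ModelConfig.from_dict(asdict(config))  # assumes from_dict mirrors asdict
    assert restored.n_heads_label_attention == 4  # plain int survives serialization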
