Skip to content

Commit c9fdf1c

Browse files
Merge branch 'main' into renovate/uv_build-0.x
2 parents 209bfac + 806ce8e commit c9fdf1c

3 files changed

Lines changed: 9 additions & 17 deletions

File tree

tests/test_pipeline.py

Lines changed: 1 addition & 8 deletions
```diff
@@ -152,14 +152,7 @@ def run_full_pipeline(
         categorical_embedding_dims=model_params["categorical_embedding_dims"],
         num_classes=model_params["num_classes"],
         attention_config=attention_config,
-        label_attention_config=(
-            LabelAttentionConfig(
-                n_head=attention_config.n_head,
-                num_classes=model_params["num_classes"],
-            )
-            if label_attention_enabled
-            else None
-        ),
+        n_heads_label_attention=attention_config.n_head,
     )

     # Create training config
```

torchTextClassifiers/tokenizers/WordPiece.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -82,7 +82,7 @@ def train(
         self._post_training()

         if save_path:
-            self.tokenizer.save(save_path)
+            self.tokenizer.save_pretrained(save_path)
             logger.info(f"💾 Tokenizer saved at {save_path}")
         if filesystem and s3_save_path:
             parent_dir = os.path.dirname(save_path)
```

torchTextClassifiers/torchTextClassifiers.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -50,11 +50,11 @@ class ModelConfig:
     """Base configuration class for text classifiers."""

     embedding_dim: int
+    num_classes: int
     categorical_vocabulary_sizes: Optional[List[int]] = None
     categorical_embedding_dims: Optional[Union[List[int], int]] = None
-    num_classes: Optional[int] = None
     attention_config: Optional[AttentionConfig] = None
-    label_attention_config: Optional[LabelAttentionConfig] = None
+    n_heads_label_attention: Optional[int] = None

     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
@@ -142,7 +142,7 @@ def __init__(
         self.embedding_dim = model_config.embedding_dim
         self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
         self.num_classes = model_config.num_classes
-        self.enable_label_attention = model_config.label_attention_config is not None
+        self.enable_label_attention = model_config.n_heads_label_attention is not None

         if self.tokenizer.output_vectorized:
             self.text_embedder = None
@@ -156,7 +156,10 @@ def __init__(
             embedding_dim=self.embedding_dim,
             padding_idx=tokenizer.padding_idx,
             attention_config=model_config.attention_config,
-            label_attention_config=model_config.label_attention_config,
+            label_attention_config=LabelAttentionConfig(
+                n_head=model_config.n_heads_label_attention,
+                num_classes=model_config.num_classes,
+            ),
         )
         self.text_embedder = TextEmbedder(
             text_embedder_config=text_embedder_config,
@@ -697,10 +700,6 @@ def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassif

         # Reconstruct model_config
         model_config = ModelConfig.from_dict(metadata["model_config"])
-        if isinstance(model_config.label_attention_config, dict):
-            model_config.label_attention_config = LabelAttentionConfig(
-                **model_config.label_attention_config
-            )

         # Create instance
         instance = cls(
```

0 commit comments

Comments (0)