Skip to content

Commit e1a5ce5

Browse files
committed
Merge branch 'main' into vocquant
2 parents 87f61fe + 4867cb8 commit e1a5ce5

4 files changed

Lines changed: 40 additions & 5 deletions

File tree

model2vec/tokenizer/normalizer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,18 @@ def replace_normalizer(
1818
:param tokenizer: The tokenizer to change.
1919
:return: The tokenizer with a replaced normalizer.
2020
"""
21+
spaces_punctuation = tokenizer.encode("a, ,", add_special_tokens=False).tokens
22+
if len(spaces_punctuation) != 3:
23+
add_space = False
24+
else:
25+
_, first_comma, second_comma = spaces_punctuation
26+
add_space = first_comma == second_comma == ","
27+
2128
normalizer = tokenizer.normalizer
2229
new_normalizers = []
2330
for char in punctuation:
24-
new_normalizers.append(Replace(char, f" {char} "))
31+
replacement = f" {char} " if add_space else f"{char} "
32+
new_normalizers.append(Replace(char, replacement))
2533

2634
new_normalizers.append(Replace(Regex(r"\s+"), " "))
2735
new_normalizers.append(Strip(right=True))

model2vec/train/classifier.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def fit(
138138
device: str = "auto",
139139
X_val: list[str] | None = None,
140140
y_val: LabelType | None = None,
141+
class_weight: torch.Tensor | None = None,
141142
) -> StaticModelForClassification:
142143
"""
143144
Fit a model.
@@ -165,6 +166,8 @@ def fit(
165166
:param device: The device to train on. If this is "auto", the device is chosen automatically.
166167
:param X_val: The texts to be used for validation.
167168
:param y_val: The labels to be used for validation.
169+
:param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
170+
have the same length as the number of classes.
168171
:return: The fitted model.
169172
:raises ValueError: If either X_val or y_val are provided, but not both.
170173
"""
@@ -199,13 +202,17 @@ def fit(
199202
base_number = int(min(max(1, (len(train_texts) / 30) // 32), 16))
200203
batch_size = int(base_number * 32)
201204
logger.info("Batch size automatically set to %d.", batch_size)
205+
206+
if class_weight is not None:
207+
if len(class_weight) != len(self.classes_):
208+
raise ValueError("class_weight must have the same length as the number of classes.")
202209

203210
logger.info("Preparing train dataset.")
204211
train_dataset = self._prepare_dataset(train_texts, train_labels)
205212
logger.info("Preparing validation dataset.")
206213
val_dataset = self._prepare_dataset(validation_texts, validation_labels)
207214

208-
c = _ClassifierLightningModule(self, learning_rate=learning_rate)
215+
c = _ClassifierLightningModule(self, learning_rate=learning_rate, class_weight=class_weight)
209216

210217
n_train_batches = len(train_dataset) // batch_size
211218
callbacks: list[Callback] = []
@@ -243,6 +250,9 @@ def fit(
243250

244251
state_dict = {}
245252
for weight_name, weight in best_model_weights["state_dict"].items():
253+
if "loss_function" in weight_name:
254+
# Skip the loss function class weight as its not needed for predictions
255+
continue
246256
state_dict[weight_name.removeprefix("model.")] = weight
247257

248258
self.load_state_dict(state_dict)
@@ -374,12 +384,12 @@ def to_pipeline(self) -> StaticModelPipeline:
374384

375385

376386
class _ClassifierLightningModule(pl.LightningModule):
377-
def __init__(self, model: StaticModelForClassification, learning_rate: float) -> None:
387+
def __init__(self, model: StaticModelForClassification, learning_rate: float, class_weight: torch.Tensor | None = None) -> None:
378388
"""Initialize the LightningModule."""
379389
super().__init__()
380390
self.model = model
381391
self.learning_rate = learning_rate
382-
self.loss_function = nn.CrossEntropyLoss() if not model.multilabel else nn.BCEWithLogitsLoss()
392+
self.loss_function = nn.CrossEntropyLoss(weight=class_weight) if not model.multilabel else nn.BCEWithLogitsLoss(pos_weight=class_weight)
383393

384394
def forward(self, x: torch.Tensor) -> torch.Tensor:
385395
"""Simple forward pass."""

model2vec/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__version_triple__ = (0, 5, 0)
1+
__version_triple__ = (0, 6, 0)
22
__version__ = ".".join(map(str, __version_triple__))

tests/test_trainable.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,23 @@ def test_y_val_none() -> None:
174174
model.fit(X, y, X_val=None, y_val=y_val)
175175
model.fit(X, y, X_val=None, y_val=None)
176176

177+
def test_class_weight() -> None:
178+
"""Test the class weight function."""
179+
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
180+
torch.random.manual_seed(42)
181+
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
182+
model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
183+
184+
X = ["dog", "cat"]
185+
y = ["0", "1"]
186+
187+
bad_class_weight = torch.tensor([1.0])
188+
with pytest.raises(ValueError):
189+
model.fit(X, y, class_weight=bad_class_weight)
190+
191+
class_weight = torch.tensor([1.0, 2.0])
192+
model.fit(X, y, class_weight=class_weight)
193+
177194

178195
@pytest.mark.parametrize(
179196
"y_multi,y_val_multi,should_crash",

0 commit comments

Comments
 (0)