@@ -138,6 +138,7 @@ def fit(
138138 device : str = "auto" ,
139139 X_val : list [str ] | None = None ,
140140 y_val : LabelType | None = None ,
141+ class_weight : torch .Tensor | None = None ,
141142 ) -> StaticModelForClassification :
142143 """
143144 Fit a model.
@@ -165,6 +166,8 @@ def fit(
165166 :param device: The device to train on. If this is "auto", the device is chosen automatically.
166167 :param X_val: The texts to be used for validation.
167168 :param y_val: The labels to be used for validation.
169+ :param class_weight: The weight of the classes. If None, all classes are weighted equally. Must
170+ have the same length as the number of classes.
168171 :return: The fitted model.
169172 :raises ValueError: If either X_val or y_val are provided, but not both.
170173 """
@@ -199,13 +202,17 @@ def fit(
199202 base_number = int (min (max (1 , (len (train_texts ) / 30 ) // 32 ), 16 ))
200203 batch_size = int (base_number * 32 )
201204 logger .info ("Batch size automatically set to %d." , batch_size )
205+
206+ if class_weight is not None :
207+ if len (class_weight ) != len (self .classes_ ):
208+ raise ValueError ("class_weight must have the same length as the number of classes." )
202209
203210 logger .info ("Preparing train dataset." )
204211 train_dataset = self ._prepare_dataset (train_texts , train_labels )
205212 logger .info ("Preparing validation dataset." )
206213 val_dataset = self ._prepare_dataset (validation_texts , validation_labels )
207214
208- c = _ClassifierLightningModule (self , learning_rate = learning_rate )
215+ c = _ClassifierLightningModule (self , learning_rate = learning_rate , class_weight = class_weight )
209216
210217 n_train_batches = len (train_dataset ) // batch_size
211218 callbacks : list [Callback ] = []
@@ -243,6 +250,9 @@ def fit(
243250
244251 state_dict = {}
245252 for weight_name , weight in best_model_weights ["state_dict" ].items ():
253+ if "loss_function" in weight_name :
254+ # Skip the loss function class weight as its not needed for predictions
255+ continue
246256 state_dict [weight_name .removeprefix ("model." )] = weight
247257
248258 self .load_state_dict (state_dict )
@@ -374,12 +384,12 @@ def to_pipeline(self) -> StaticModelPipeline:
374384
375385
376386class _ClassifierLightningModule (pl .LightningModule ):
377- def __init__ (self , model : StaticModelForClassification , learning_rate : float ) -> None :
387+ def __init__ (self , model : StaticModelForClassification , learning_rate : float , class_weight : torch . Tensor | None = None ) -> None :
378388 """Initialize the LightningModule."""
379389 super ().__init__ ()
380390 self .model = model
381391 self .learning_rate = learning_rate
382- self .loss_function = nn .CrossEntropyLoss () if not model .multilabel else nn .BCEWithLogitsLoss ()
392+ self .loss_function = nn .CrossEntropyLoss (weight = class_weight ) if not model .multilabel else nn .BCEWithLogitsLoss (pos_weight = class_weight )
383393
384394 def forward (self , x : torch .Tensor ) -> torch .Tensor :
385395 """Simple forward pass."""
0 commit comments