diff --git a/.gitignore b/.gitignore
index 42d6f0ed..9cae7ea8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ __pycache__/
 
 #data
 data.json
+data.json*
 
 #logs
 logs/
diff --git a/configs/config.yaml b/configs/config.yaml
index d1456c9e..006f943f 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -23,6 +23,8 @@ loss_alpha: 0.75
 loss_gamma: 0
 label_smoothing: 0
 loss_reduction: "sum"
+negative_rate: 0.75
+neg_span_masking: "global_w_threshold"
 
 # Learning Rate and weight decay Configuration
 lr_encoder: 1e-5
diff --git a/gliner/config.py b/gliner/config.py
index c7af214d..a5fe60ec 100644
--- a/gliner/config.py
+++ b/gliner/config.py
@@ -22,6 +22,7 @@ def __init__(self,
                  max_types: int = 25,
                  max_len: int = 384,
                  words_splitter_type: str = "whitespace",
+                 neg_span_masking: str = None,
                  has_rnn: bool = True,
                  fuse_layers: bool = False,
                  embed_ent_token: bool = True,
@@ -68,6 +69,7 @@ def __init__(self,
         self.embed_ent_token = embed_ent_token
         self.ent_token = ent_token
         self.sep_token = sep_token
+        self.neg_span_masking = neg_span_masking
 
         # Register the configuration
         from transformers import CONFIG_MAPPING
diff --git a/gliner/modeling/base.py b/gliner/modeling/base.py
index bcd04cdc..ebfa6a4d 100644
--- a/gliner/modeling/base.py
+++ b/gliner/modeling/base.py
@@ -257,7 +257,7 @@ def forward(self,
 
     def loss(self, scores, labels, prompts_embedding_mask, mask_label,
              alpha: float = -1., gamma: float = 0.0, label_smoothing: float = 0.0,
-             reduction: str = 'sum', **kwargs):
+             reduction: str = 'sum', negative_rate: float = 0.75, neg_span_masking: str = None, **kwargs):
 
         batch_size = scores.shape[0]
         num_classes = prompts_embedding_mask.shape[-1]
@@ -274,6 +274,46 @@
 
         all_losses = all_losses * mask_label.float()
 
+        # Optional masking of negative (label == 0) spans so the loss is not
+        # dominated by the many easy negatives.
+        if neg_span_masking is not None:
+
+            if neg_span_masking == "global_w_threshold":
+                # Keep all positives; keep each negative with prob (1 - negative_rate).
+                mask_negative_examples = (torch.rand_like(labels, dtype=torch.float) + labels > negative_rate).float()
+                all_losses = all_losses * mask_negative_examples
+
+            elif neg_span_masking == "global_wo_threshold":
+                # Keep each negative with prob 1 - sigmoid(score); positives always kept.
+                p = torch.sigmoid(scores)
+                random_mask = torch.bernoulli(1 - p) + labels
+                mask_negative_examples = torch.where(labels == 1, torch.ones_like(labels), random_mask)
+                all_losses = all_losses * mask_negative_examples
+
+            elif neg_span_masking == "entity_w_threshold":
+                # Rows with entities keep only their positive spans; entity-free
+                # rows keep each span with prob (1 - negative_rate).
+                mask_negative_examples = labels.clone()
+                zero_rows = labels.sum(dim=1) == 0
+                # Sample on labels' device to avoid a CPU/GPU mismatch on GPU.
+                mask_negative_examples[zero_rows] = (torch.rand((int(zero_rows.sum()), labels.size(1)), device=labels.device) >= negative_rate).float()
+                all_losses = all_losses * mask_negative_examples
+
+            elif neg_span_masking == "entity_wo_threshold":
+                # Entity-free rows keep each span with prob sigmoid(score).
+                p = torch.sigmoid(scores)
+                mask_negative_examples = labels.clone()
+                rows_to_sample = labels.sum(dim=1) == 0
+                mask_negative_examples[rows_to_sample] = torch.bernoulli(p[rows_to_sample])
+                # Fix: the sampled mask was previously computed but never applied.
+                all_losses = all_losses * mask_negative_examples
+
+            else:
+                # NOTE(review): requires `import warnings` at module top -- confirm.
+                warnings.warn(
+                    f"Invalid value for config 'neg_span_masking': '{neg_span_masking}'.")
+
+
 
         if reduction == "mean":
             loss = all_losses.mean()
         elif reduction == 'sum':