Skip to content

Commit e93c046

Browse files
authored
Merge pull request #351 from maxwbuckley/fix-test-suite
Fix 35 pre-existing test failures; add CI workflow
2 parents e052cc0 + 7fd048f commit e93c046

27 files changed

Lines changed: 1305 additions & 1186 deletions

.github/workflows/tests.yml

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
branches:
9+
- main
10+
workflow_dispatch:
11+
12+
concurrency:
13+
group: ${{ github.workflow }}-${{ github.ref }}
14+
cancel-in-progress: true
15+
16+
jobs:
17+
test:
18+
name: pytest (Python ${{ matrix.python-version }})
19+
runs-on: ubuntu-latest
20+
strategy:
21+
fail-fast: false
22+
matrix:
23+
python-version: ["3.10", "3.12"]
24+
25+
steps:
26+
- name: Check out repository
27+
uses: actions/checkout@v4
28+
29+
- name: Set up Python
30+
uses: actions/setup-python@v5
31+
with:
32+
python-version: ${{ matrix.python-version }}
33+
34+
- name: Cache pip
35+
uses: actions/cache@v4
36+
with:
37+
path: ~/.cache/pip
38+
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt', 'pyproject.toml') }}
39+
restore-keys: |
40+
${{ runner.os }}-py${{ matrix.python-version }}-pip-
41+
42+
- name: Install dependencies
43+
run: |
44+
python -m pip install --upgrade pip
45+
pip install -r requirements.txt
46+
pip install pytest pytest-asyncio sentencepiece onnxruntime
47+
48+
- name: Run pytest
49+
run: pytest -q --tb=short
50+
51+
lint:
52+
name: ruff
53+
runs-on: ubuntu-latest
54+
steps:
55+
- name: Check out repository
56+
uses: actions/checkout@v4
57+
58+
- name: Set up Python
59+
uses: actions/setup-python@v5
60+
with:
61+
python-version: "3.12"
62+
63+
- name: Install ruff
64+
run: pip install ruff
65+
66+
- name: ruff check
67+
run: ruff check gliner

gliner/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def __init__(
230230
the per-type entity drop probability. Defaults to (0.0, 0.4).
231231
augment_rel_drop_prob (tuple, optional): Range (min, max) from which to sample
232232
the per-type relation drop probability. Defaults to (0.0, 0.4).
233+
augment_add_other_prob (float, optional): Probability of adding "other" relation to a pair with no relation.
233234
rel_id_to_classes (Optional[dict]): Mapping from relation class IDs to class names. Defaults to None.
234235
**kwargs: Additional keyword arguments passed to UniEncoderConfig.
235236

gliner/data_processing/processor.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,8 @@ def create_labels(self, batch, blank=None):
13681368
if self.config.decoder_mode == "span":
13691369
# Collect entity labels in order of appearance
13701370
sorted_entities = sorted(ner, key=lambda x: (x[0], x[1])) if ner else []
1371-
for start, end, label in sorted_entities:
1371+
# start, end, label = entity
1372+
for _, end, label in sorted_entities:
13721373
if label in classes_to_id and end < num_tokens:
13731374
decoder_label_strings.append(label)
13741375
elif self.config.decoder_mode == "prompt":
@@ -1477,8 +1478,8 @@ def augment_example(self, example, ner_negatives=None, other_keyword="other"):
14771478
rel_drop_prob = random.uniform(*self.config.augment_rel_drop_prob)
14781479
add_other = random.random() < self.config.augment_add_other_prob
14791480

1480-
all_ent_types = set(e[-1] for e in ner)
1481-
all_rel_types = set(r[-1] for r in relations) if relations else set()
1481+
all_ent_types = {e[-1] for e in ner}
1482+
all_rel_types = {r[-1] for r in relations} if relations else set()
14821483

14831484
# "other" is exempt from dropping since it's our replacement label
14841485
dropped_ent_types = {t for t in all_ent_types if t != other_keyword and random.random() < ent_drop_prob}
@@ -1512,7 +1513,7 @@ def augment_example(self, example, ner_negatives=None, other_keyword="other"):
15121513
old_to_new_idx[i] = len(new_ner)
15131514
if ent_type in dropped_ent_types and add_other:
15141515
# Replace dropped type with "other"
1515-
new_ner.append(list(ent[:-1]) + [other_keyword])
1516+
new_ner.append([*ent[:-1], other_keyword])
15161517
else:
15171518
new_ner.append(ent)
15181519

@@ -1649,13 +1650,12 @@ def collate_raw_batch(
16491650
relation extraction.
16501651
"""
16511652
# Apply data augmentation if enabled (only during dynamic mapping generation)
1652-
augment_prob = getattr(self.config, 'augment_data_prob', 0.0)
1653+
augment_prob = getattr(self.config, "augment_data_prob", 0.0)
16531654
if augment_prob > 0.0 and class_to_ids is None and entity_types is None:
16541655
if ner_negatives is None:
16551656
ner_negatives = get_negatives(batch_list, sampled_neg=100, key="ner")
16561657
batch_list = [
1657-
self.augment_example(b, ner_negatives) if random.random() < augment_prob else b
1658-
for b in batch_list
1658+
self.augment_example(b, ner_negatives) if random.random() < augment_prob else b for b in batch_list
16591659
]
16601660
if class_to_ids is None and entity_types is None:
16611661
# Dynamically infer per-example mappings
@@ -1768,7 +1768,11 @@ def preprocess_example(self, tokens, ner, classes_to_id, relations, rel_classes_
17681768
head_idx, tail_idx, rel_type = rel
17691769

17701770
# Use compact indices so rel_idx aligns with target_span_rep positions
1771-
if head_idx in entity_to_compact_idx and tail_idx in entity_to_compact_idx and rel_type in rel_classes_to_id:
1771+
if (
1772+
head_idx in entity_to_compact_idx
1773+
and tail_idx in entity_to_compact_idx
1774+
and rel_type in rel_classes_to_id
1775+
):
17721776
rel_idx_list.append([entity_to_compact_idx[head_idx], entity_to_compact_idx[tail_idx]])
17731777
rel_label_list.append(rel_classes_to_id[rel_type])
17741778

@@ -1833,7 +1837,9 @@ def create_batch_dict(self, batch, class_to_ids, id_to_classes, rel_class_to_ids
18331837
"rel_id_to_classes": rel_id_to_classes,
18341838
}
18351839

1836-
def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=(1.0, 10.0)):
1840+
def create_relation_labels(
1841+
self, batch, add_reversed_negatives=True, add_random_negatives=True, negative_ratio=(1.0, 10.0)
1842+
):
18371843
"""Create relation labels with negative pair sampling.
18381844
18391845
Overrides the span-based version to work with token-level entity representations.
@@ -1870,7 +1876,7 @@ def create_relation_labels(self, batch, add_reversed_negatives=True, add_random_
18701876

18711877
# Batch CPU transfer to avoid per-element .item() sync
18721878
batch_ents_cpu = batch_ents.tolist()
1873-
max_En = max(max(batch_ents_cpu), 1)
1879+
max_En = max(*batch_ents_cpu, 1)
18741880

18751881
rel_class_to_ids = batch["rel_class_to_ids"]
18761882
if isinstance(rel_class_to_ids, list):

gliner/decoding/decoder.py

Lines changed: 33 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class Span:
2020
class_probs: Optional dict of top-k class probabilities
2121
generated_labels: Optional list of generated labels (for generative decoders)
2222
"""
23+
2324
start: int
2425
end: int
2526
entity_type: str
@@ -260,7 +261,7 @@ def _decode_batch_item(
260261
"""
261262
# Mask probabilities to only include input spans (for efficiency)
262263
if input_spans_i is not None:
263-
L, K_dim, C = probs_i.shape
264+
L, K_dim, _ = probs_i.shape
264265
span_filter = torch.zeros(L, K_dim, dtype=torch.bool, device=probs_i.device)
265266
for word_start, word_end in input_spans_i:
266267
width = word_end - word_start
@@ -358,18 +359,20 @@ class IDs to class names.
358359
if B == 1:
359360
id_to_class_0 = self._get_id_to_class_for_sample(id_to_classes, 0)
360361
input_spans_0 = input_spans[0] if input_spans is not None else None
361-
return [self._decode_batch_item(
362-
probs_i=probs[0],
363-
tokens_i=tokens[0],
364-
id_to_class_i=id_to_class_0,
365-
K=K,
366-
threshold=threshold,
367-
flat_ner=flat_ner,
368-
multi_label=multi_label,
369-
span_label_map=span_label_maps[0],
370-
return_class_probs=return_class_probs,
371-
input_spans_i=input_spans_0,
372-
)]
362+
return [
363+
self._decode_batch_item(
364+
probs_i=probs[0],
365+
tokens_i=tokens[0],
366+
id_to_class_i=id_to_class_0,
367+
K=K,
368+
threshold=threshold,
369+
flat_ner=flat_ner,
370+
multi_label=multi_label,
371+
span_label_map=span_label_maps[0],
372+
return_class_probs=return_class_probs,
373+
input_spans_i=input_spans_0,
374+
)
375+
]
373376

374377
# Apply input_spans mask at batch level (one mask, one multiply)
375378
if input_spans is not None:
@@ -392,9 +395,7 @@ class IDs to class names.
392395
return [[] for _ in range(B)]
393396

394397
# ONE vectorized valid-span check across entire batch
395-
num_tokens = torch.tensor(
396-
[len(t) for t in tokens], device=probs.device, dtype=torch.long
397-
)
398+
num_tokens = torch.tensor([len(t) for t in tokens], device=probs.device, dtype=torch.long)
398399
valid = (s_idx + k_idx + 1) <= num_tokens[b_idx]
399400
b_idx = b_idx[valid]
400401
s_idx = s_idx[valid]
@@ -427,15 +428,11 @@ class IDs to class names.
427428
top_indices_list = all_top_indices.tolist()
428429

429430
# Pre-resolve id_to_class mappings per batch item
430-
id_to_class_per_item = [
431-
self._get_id_to_class_for_sample(id_to_classes, i) for i in range(B)
432-
]
431+
id_to_class_per_item = [self._get_id_to_class_for_sample(id_to_classes, i) for i in range(B)]
433432

434433
# Group by batch item and build Span objects (pure Python)
435434
batch_spans: List[List[Span]] = [[] for _ in range(B)]
436-
for j, (b, s, k, c, flat_idx, score) in enumerate(
437-
zip(b_list, s_list, k_list, c_list, flat_idxs, scores)
438-
):
435+
for j, (b, s, k, c, flat_idx, score) in enumerate(zip(b_list, s_list, k_list, c_list, flat_idxs, scores)):
439436
id_to_class_i = id_to_class_per_item[b]
440437

441438
class_probs = None
@@ -445,16 +442,11 @@ class IDs to class names.
445442
class_name = id_to_class_i.get(idx + 1, f"class_{idx}")
446443
class_probs[class_name] = prob
447444

448-
span = self._build_span_tuple(
449-
s, k, c, flat_idx, score, id_to_class_i, span_label_maps[b], class_probs
450-
)
445+
span = self._build_span_tuple(s, k, c, flat_idx, score, id_to_class_i, span_label_maps[b], class_probs)
451446
batch_spans[b].append(span)
452447

453448
# Per-item greedy search (inherently sequential, but cheap pure Python)
454-
return [
455-
self.greedy_search(spans, flat_ner, multi_label=multi_label)
456-
for spans in batch_spans
457-
]
449+
return [self.greedy_search(spans, flat_ner, multi_label=multi_label) for spans in batch_spans]
458450

459451
def decode(
460452
self,
@@ -544,13 +536,7 @@ def _build_span_tuple(
544536
Span: Span object with entity properties.
545537
"""
546538
ent_type = id_to_class[class_idx + 1] # +1 because 0 is <pad>
547-
return Span(
548-
start=start,
549-
end=start + width,
550-
entity_type=ent_type,
551-
score=score,
552-
class_probs=class_probs
553-
)
539+
return Span(start=start, end=start + width, entity_type=ent_type, score=score, class_probs=class_probs)
554540

555541

556542
class SpanGenerativeDecoder(BaseSpanDecoder):
@@ -679,7 +665,7 @@ def _build_span_tuple(
679665
entity_type=ent_type,
680666
score=score,
681667
class_probs=class_probs,
682-
generated_labels=gen_ent_type
668+
generated_labels=gen_ent_type,
683669
)
684670

685671
def decode_generative(
@@ -864,15 +850,8 @@ def _decode_relations_batch(
864850
# 3. Vectorized index-validity check
865851
head = rel_idx[..., 0] # (B, R)
866852
tail = rel_idx[..., 1] # (B, R)
867-
num_spans = torch.tensor(
868-
[len(s) for s in spans], device=rel_idx.device, dtype=head.dtype
869-
) # (B,)
870-
valid = (
871-
(head >= 0)
872-
& (tail >= 0)
873-
& (head < num_spans[:, None])
874-
& (tail < num_spans[:, None])
875-
) # (B, R)
853+
num_spans = torch.tensor([len(s) for s in spans], device=rel_idx.device, dtype=head.dtype) # (B,)
854+
valid = (head >= 0) & (tail >= 0) & (head < num_spans[:, None]) & (tail < num_spans[:, None]) # (B, R)
876855
rel_probs = rel_probs * valid.unsqueeze(-1)
877856

878857
# 4. Single torch.where on the full (B, R, C) tensor
@@ -898,9 +877,7 @@ def _decode_relations_batch(
898877
mapping = rel_id_to_classes[b] if is_list else rel_id_to_classes
899878
if c1 not in mapping:
900879
continue
901-
relations[b].append(
902-
(int(head_list[k]), mapping[c1], int(tail_list[k]), scores[k])
903-
)
880+
relations[b].append((int(head_list[k]), mapping[c1], int(tail_list[k]), scores[k]))
904881

905882
return relations
906883

@@ -955,13 +932,7 @@ def _build_span_tuple(
955932
Span: Span object with entity properties.
956933
"""
957934
ent_type = id_to_class[class_idx + 1] # +1 because 0 is <pad>
958-
return Span(
959-
start=start,
960-
end=start + width,
961-
entity_type=ent_type,
962-
score=score,
963-
class_probs=class_probs
964-
)
935+
return Span(start=start, end=start + width, entity_type=ent_type, score=score, class_probs=class_probs)
965936

966937
def _build_entity_span_to_decoded_idx(
967938
self,
@@ -1151,6 +1122,7 @@ def decode(
11511122
rel_idx: Optional tensor of shape (batch_size, num_relations, 2).
11521123
rel_logits: Optional tensor of shape (batch_size, num_relations, num_relation_classes).
11531124
rel_mask: Optional boolean tensor of shape (batch_size, num_relations).
1125+
return_class_probs: Whether to include class probabilities in the decoded spans.
11541126
flat_ner: If True, applies greedy filtering for non-overlapping entities.
11551127
threshold: Minimum confidence score for entity predictions.
11561128
relation_threshold: Minimum confidence score for relation predictions.
@@ -1266,13 +1238,8 @@ def _calculate_span_score(
12661238
start_score = start_cpu[st][cls_st]
12671239
end_score = end_cpu[ed][cls_ed]
12681240
# The span score is the minimum value among all scores
1269-
spn_score = min(min(ins), start_score, end_score)
1270-
span_i.append(Span(
1271-
start=st,
1272-
end=ed,
1273-
entity_type=id_to_classes[cls_st + 1],
1274-
score=spn_score
1275-
))
1241+
spn_score = min(*ins, start_score, end_score)
1242+
span_i.append(Span(start=st, end=ed, entity_type=id_to_classes[cls_st + 1], score=spn_score))
12761243
return span_i
12771244

12781245
def _decode_from_spans(
@@ -1349,12 +1316,7 @@ class IDs to class names.
13491316
class_id = class_idx + 1 # Convert to 1-indexed
13501317
if class_id in id_to_class_i:
13511318
entity_type = id_to_class_i[class_id]
1352-
span_scores.append(Span(
1353-
start=span_start,
1354-
end=span_end,
1355-
entity_type=entity_type,
1356-
score=prob
1357-
))
1319+
span_scores.append(Span(start=span_start, end=span_end, entity_type=entity_type, score=prob))
13581320

13591321
# Apply greedy search to handle overlapping spans if needed
13601322
span_i = self.greedy_search(span_scores, flat_ner, multi_label)
@@ -1664,6 +1626,8 @@ def decode(
16641626
rel_id_to_classes: Optional mapping from relation class IDs to relation names.
16651627
If None, relation decoding is skipped and empty relation lists are returned.
16661628
Can be either a single Dict or List[Dict] for per-sample mappings.
1629+
entity_spans: Optional tensor of pre-computed entity spans to use instead
1630+
of decoding them from model_output.
16671631
Class IDs are 1-indexed.
16681632
**kwargs: Additional keyword arguments passed to the parent class decode method.
16691633

0 commit comments

Comments
 (0)