Commit 595cc41

committed
relation extraction prototype
1 parent 7cb79f8 commit 595cc41

8 files changed

Lines changed: 397 additions & 157 deletions

docs/pipeline.rst

Lines changed: 41 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -75,17 +75,22 @@ For simplicity, one can use one of the preconfigured pipelines:
7575

7676
.. code-block:: python
7777
78-
from renard.pipeline.preconfigured import bert_pipeline
78+
from renard.pipeline.preconfigured import co_occurrence_pipeline
7979
8080
with open("./my_doc.txt") as f:
8181
text = f.read()
8282
83-
pipeline = bert_pipeline(
84-
graph_extractor_kwargs={"co_occurrences_dist": (1, "sentences")}
85-
)
83+
pipeline = co_occurrence_pipeline()
8684
out = pipeline(text)
8785
8886
87+
The following preconfigured pipelines are available:
88+
89+
- :func:`.co_occurrence_pipeline`
90+
- :func:`.conversational_pipeline`
91+
- :func:`.relational_pipeline`
92+
93+
8994
Pipeline Output: the Pipeline State
9095
===================================
9196

@@ -137,7 +142,7 @@ Tokenization
137142
Tokenization is the task of cutting text into *tokens*. It is usually
138143
the first task to apply to a text. 2 tokenizers are available:
139144

140-
- :class:`.NLTKTokenizer`
145+
- :class:`.NLTKTokenizer` is the tokenizer from NLTK.
141146
- :class:`.StanfordCoreNLPPipeline` does contain a tokenizer as part
142147
of its full NLP pipeline.
143148

@@ -148,16 +153,19 @@ Named Entity Recognition
148153
Named entity recognition (NER) detects entity occurrences in the
149154
text. 3 modules are available:
150155

151-
- :class:`.NLTKNamedEntityRecognizer`
152-
- :class:`.BertNamedEntityRecognizer`
156+
- :class:`.NLTKNamedEntityRecognizer` is a lightweight NER module from
157+
NLTK, based on POS tagging and rules.
158+
- :class:`.BertNamedEntityRecognizer` is a NER module employing a
159+
finetuned BERT model.
153160
- :class:`.StanfordCoreNLPPipeline` contains a NER model as part of
154161
its full NLP pipeline.
155162

156163

157164
Coreference Resolution
158165
----------------------
159166

160-
- :class:`.SpacyCorefereeCoreferenceResolver`
167+
- :class:`.SpacyCorefereeCoreferenceResolver` uses the spaCy coreferee
168+
module.
161169
- :class:`.BertCoreferenceResolver`, using the Tibert library.
162170
- :class:`.StanfordCoreNLPPipeline` can execute a coreference
163171
resolution model as part of its pipeline.
@@ -166,14 +174,14 @@ Coreference Resolution
166174
Quote Detection
167175
---------------
168176

169-
- :class:`.QuoteDetector`
177+
- :class:`.QuoteDetector` detects quotes using simple logic.
170178

171179

172180
Sentiment Analysis
173181
------------------
174182

175183
- :class:`.NLTKSentimentAnalyzer` leverages NLTK's Vader for sentiment
176-
analysis
184+
analysis.
177185

178186

179187
Characters Extraction
@@ -183,21 +191,36 @@ Characters extraction (or alias resolution) extracts characters from
183191
occurrences detected using NER. This is done by assigning each mention
184192
to a unique character.
185193

186-
- :class:`.NaiveCharacterUnifier`
187-
- :class:`.GraphRulesCharacterUnifier`
194+
- :class:`.NaiveCharacterUnifier` assigns each mention with a unique
195+
form to a character.
196+
- :class:`.GraphRulesCharacterUnifier` uses a set of rules to assign
197+
each mention to a character.
198+
199+
200+
Relation Extraction
201+
-------------------
202+
203+
- :class:`.T5RelationExtractor` extracts relations between characters
204+
using a finetuned T5 model.
188205

189206

190207
Speaker Attribution
191208
-------------------
192209

193-
- :class:`.BertSpeakerDetector`
210+
- :class:`.BertSpeakerDetector` detects speakers using a finetuned BERT
211+
model.
194212

195213

196214
Graph Extraction
197215
----------------
198216

199-
- :class:`.CoOccurrencesGraphExtractor`
200-
- :class:`.ConversationalGraphExtractor`
217+
- :class:`.CoOccurrencesGraphExtractor` extracts a graph of
218+
co-occurrence between characters.
219+
- :class:`.ConversationalGraphExtractor` extracts a conversational
220+
graph, built either from conversations between characters or from character
221+
mentions.
222+
- :class:`.RelationalGraphExtractor` extracts a relational graph,
223+
where the relation between each pair of characters is typed.
201224

202225

203226
Dynamic Graphs
@@ -241,7 +264,9 @@ When executing the above block of code, the output attribute
241264
[<networkx.classes.graph.Graph object at 0x7fd9e9115900>]
242265

243266
See :class:`.CoOccurrencesGraphExtractor` for more details on the
244-
usage of the ``dynamic`` and ``dynamic_window`` arguments.
267+
usage of the ``dynamic`` and ``dynamic_window`` arguments. Note that,
268+
currently, only the co-occurrence graph extractor supports dynamic
269+
networks.
245270

246271
Plot and export functions work as one would expect
247272
intuitively. :meth:`.PipelineState.plot_graph` allows visualizing the

renard/pipeline/core.py

Lines changed: 5 additions & 1 deletion
@@ -35,6 +35,7 @@
3535
from renard.pipeline.character_unification import Character
3636
from renard.pipeline.ner import NEREntity
3737
from renard.pipeline.quote_detection import Quote
38+
from renard.pipeline.relation_extraction import Relation
3839
import matplotlib.pyplot as plt
3940

4041

@@ -175,7 +176,10 @@ class PipelineState:
175176
speakers: Optional[List[Optional[Character]]] = None
176177

177178
#: polarity of each sentence
178-
sentences_polarities: Optional[List[float]] = None
179+
sentence_polarities: Optional[List[float]] = None
180+
181+
#: relations detected in each sentence
182+
sentence_relations: Optional[List[List[Relation]]] = None
179183

180184
#: NER entities
181185
entities: Optional[List[NEREntity]] = None

renard/pipeline/graph_extraction.py

Lines changed: 68 additions & 18 deletions
@@ -1,5 +1,6 @@
11
from typing import Dict, Any, List, Set, Optional, Tuple, Literal, Union
22
import itertools as it
3+
from collections import defaultdict
34
import operator
45

56
import networkx as nx
@@ -11,6 +12,7 @@
1112
from renard.pipeline.core import PipelineStep
1213
from renard.pipeline.character_unification import Character
1314
from renard.pipeline.quote_detection import Quote
15+
from renard.pipeline.relation_extraction import Relation
1416

1517

1618
def sent_index_for_token_index(token_index: int, sentences: List[List[str]]) -> int:
@@ -147,7 +149,7 @@ def __call__(
147149
sentences: List[List[str]],
148150
char2token: Optional[List[int]] = None,
149151
dynamic_blocks: Optional[BlockBounds] = None,
150-
sentences_polarities: Optional[List[float]] = None,
152+
sentence_polarities: Optional[List[float]] = None,
151153
entities: Optional[List[NEREntity]] = None,
152154
co_occurrences_blocks: Optional[BlockBounds] = None,
153155
**kwargs,
@@ -194,13 +196,13 @@ def __call__(
194196
self.dynamic_overlap,
195197
dynamic_blocks,
196198
sentences,
197-
sentences_polarities,
199+
sentence_polarities,
198200
co_occurrences_blocks,
199201
)
200202
}
201203
return {
202204
"character_network": self._extract_graph(
203-
mentions, sentences, sentences_polarities, co_occurrences_blocks
205+
mentions, sentences, sentence_polarities, co_occurrences_blocks
204206
)
205207
}
206208

@@ -257,24 +259,24 @@ def _extract_graph(
257259
self,
258260
mentions: List[Tuple[Any, NEREntity]],
259261
sentences: List[List[str]],
260-
sentences_polarities: Optional[List[float]],
262+
sentence_polarities: Optional[List[float]],
261263
co_occurrences_blocks: Optional[BlockBounds],
262264
) -> nx.Graph:
263265
"""
264266
:param mentions: A list of entity mentions, ordered by
265267
appearance, each of the form (KEY MENTION). KEY
266268
determines the unicity of the entity.
267-
:param sentences: if specified, ``sentences_polarities`` must
269+
:param sentences: if specified, ``sentence_polarities`` must
268270
be specified as well.
269-
:param sentences_polarities: if specified, ``sentences`` must
271+
:param sentence_polarities: if specified, ``sentences`` must
270272
be specified as well. In that case, edges are annotated
271273
with the ``'polarity'`` attribute, indicating the polarity
272274
of the relationship between two characters. Polarity
273275
between two interactions is computed as the strongest
274276
sentence polarity between those two mentions.
275277
:param co_occurrences_blocks: only unit 'tokens' is accepted.
276278
"""
277-
compute_polarity = not sentences_polarities is None
279+
compute_polarity = not sentence_polarities is None
278280

279281
assert co_occurrences_blocks is None or co_occurrences_blocks[1] == "tokens"
280282
if co_occurrences_blocks is None:
@@ -324,15 +326,15 @@ def _extract_graph(
324326

325327
if compute_polarity:
326328
assert not sentences is None
327-
assert not sentences_polarities is None
329+
assert not sentence_polarities is None
328330
# TODO: optim
329331
first_sent_idx = sent_index_for_token_index(
330332
mention1.start_idx, sentences
331333
)
332334
last_sent_idx = sent_index_for_token_index(
333335
mention2.start_idx, sentences
334336
)
335-
sents_polarities_between_mentions = sentences_polarities[
337+
sents_polarities_between_mentions = sentence_polarities[
336338
first_sent_idx : last_sent_idx + 1
337339
]
338340
polarity = max(sents_polarities_between_mentions, key=abs)
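The polarity rule in this hunk (keep the sentence polarity with the largest absolute value between the two mentions) can be illustrated in isolation. This is a minimal sketch, not Renard code: `strongest_polarity` is a hypothetical helper, and the sentence indices are passed in directly instead of being derived with `sent_index_for_token_index`.

```python
from typing import List


def strongest_polarity(
    sentence_polarities: List[float], first_sent_idx: int, last_sent_idx: int
) -> float:
    """Return the polarity with the largest magnitude between two sentences.

    Hypothetical helper mirroring the slice-then-max(key=abs) logic of
    the hunk above. Both bounds are inclusive.
    """
    between = sentence_polarities[first_sent_idx : last_sent_idx + 1]
    # key=abs compares magnitudes, but the signed value is returned,
    # so a strongly negative sentence wins over a mildly positive one.
    return max(between, key=abs)


if __name__ == "__main__":
    polarities = [0.1, -0.8, 0.5]
    print(strongest_polarity(polarities, 0, 2))
```

Note that `max` raises `ValueError` on an empty slice, so this sketch assumes `first_sent_idx <= last_sent_idx`, which holds when mentions are ordered by appearance.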
@@ -349,7 +351,7 @@ def _extract_dynamic_graph(
349351
overlap: int,
350352
dynamic_blocks: Optional[BlockBounds],
351353
sentences: List[List[str]],
352-
sentences_polarities: Optional[List[float]],
354+
sentence_polarities: Optional[List[float]],
353355
co_occurrences_blocks: Optional[BlockBounds],
354356
) -> List[nx.Graph]:
355357
"""
@@ -367,14 +369,14 @@ def _extract_dynamic_graph(
367369
"""
368370
assert co_occurrences_blocks is None or co_occurrences_blocks[1] == "tokens"
369371
assert window is None or dynamic_blocks is None
370-
compute_polarity = not sentences is None and not sentences_polarities is None
372+
compute_polarity = not sentences is None and not sentence_polarities is None
371373

372374
if not window is None:
373375
return [
374376
self._extract_graph(
375377
[elt for elt in ct if not elt is None],
376378
sentences,
377-
sentences_polarities,
379+
sentence_polarities,
378380
co_occurrences_blocks,
379381
)
380382
for ct in windowed(mentions, window, step=window - overlap)
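The windowed branch above slices the mention list into overlapping windows with `step=window - overlap`. A rough standard-library illustration of that slicing follows. `overlapping_windows` is a hypothetical helper, not part of Renard, and the import of `windowed` is not shown in this diff. Instead of padding the last window and filtering out `None` afterwards, as the commit does, the sketch simply returns a shorter final window.

```python
from typing import List, TypeVar

T = TypeVar("T")


def overlapping_windows(items: List[T], window: int, overlap: int) -> List[List[T]]:
    """Cut ``items`` into windows of size ``window`` sharing ``overlap`` items.

    Hypothetical stand-in for the windowing step. The last window may be
    shorter than ``window`` (no padding value is used).
    """
    step = window - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than window")
    windows = []
    for start in range(0, len(items), step):
        windows.append(items[start : start + window])
        # Stop once a window has reached the end of the list, so that
        # no window consisting only of already-covered items is emitted.
        if start + window >= len(items):
            break
    return windows


if __name__ == "__main__":
    print(overlapping_windows(list(range(1, 7)), window=3, overlap=1))
```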
@@ -391,10 +393,10 @@ def _extract_dynamic_graph(
391393
sent_start, sent_end = sent_indices_for_block(dynamic_block, sentences)
392394
block_sentences = sentences[sent_start : sent_end + 1]
393395

394-
block_sentences_polarities = None
396+
block_sentence_polarities = None
395397
if compute_polarity:
396-
assert not sentences_polarities is None
397-
block_sentences_polarities = sentences_polarities[
398+
assert not sentence_polarities is None
399+
block_sentence_polarities = sentence_polarities[
398400
sent_start : sent_end + 1
399401
]
400402

@@ -412,7 +414,7 @@ def _extract_dynamic_graph(
412414
self._extract_graph(
413415
block_mentions,
414416
block_sentences,
415-
block_sentences_polarities,
417+
block_sentence_polarities,
416418
block_co_occ_bounds,
417419
)
418420
)
@@ -441,7 +443,7 @@ def production(self) -> Set[str]:
441443
return {"character_network"}
442444

443445
def optional_needs(self) -> Set[str]:
444-
return {"sentences_polarities"}
446+
return {"sentence_polarities"}
445447

446448

447449
class ConversationalGraphExtractor(PipelineStep):
@@ -588,7 +590,6 @@ def __call__(
588590
characters: Set[Character],
589591
**kwargs,
590592
) -> Dict[str, Any]:
591-
592593
if self.graph_type == "conversation":
593594
G = self._conversation_extract(sentences, quotes, speakers, characters)
594595
elif self.graph_type == "mention":
@@ -608,3 +609,52 @@ def needs(self) -> Set[str]:
608609
def production(self) -> Set[str]:
609610
"""character_network"""
610611
return {"character_network"}
612+
613+
614+
class RelationalGraphExtractor(PipelineStep):
615+
"""A graph extractor using relations between characters.
616+
617+
.. note::
618+
619+
Does not support dynamic networks yet.
620+
"""
621+
622+
def __init__(self, min_rel_occurrences: int = 1):
623+
self.min_rel_occurrences = min_rel_occurrences
624+
625+
def __call__(
626+
self,
627+
characters: list[Character],
628+
sentence_relations: list[list[Relation]],
629+
**kwargs,
630+
) -> dict[str, Any]:
631+
G = nx.Graph()
632+
for character in characters:
633+
G.add_node(character)
634+
635+
# { (char1, char2) => { relation: counter } }
636+
edge_relations = defaultdict(dict)
637+
for relations in sentence_relations:
638+
for subj, rel, obj in relations:
639+
counter = edge_relations[(subj, obj)].get(rel, 0)
640+
edge_relations[(subj, obj)][rel] = counter + 1
641+
642+
for (char1, char2), counter in edge_relations.items():
643+
relations = {
644+
rel
645+
for rel, count in counter.items()
646+
if count >= self.min_rel_occurrences
647+
}
648+
if len(relations) > 0:
649+
G.add_edge(char1, char2, relations=relations)
650+
651+
return {"character_network": G}
652+
653+
def supported_langs(self) -> Literal["any"]:
654+
return "any"
655+
656+
def needs(self) -> set[str]:
657+
return {"characters", "sentence_relations"}
658+
659+
def production(self) -> set[str]:
660+
return {"character_network"}
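The counting and thresholding step of `RelationalGraphExtractor.__call__` can be sketched with the standard library alone. The assumptions here are loud ones: plain strings stand in for `Character` objects, 3-tuples stand in for `Relation` (the diff only shows that a relation unpacks as `subj, rel, obj`), and `aggregate_relations` is a hypothetical name. Unlike the commit, this variant keys on an unordered pair. That is a design choice of the sketch: the extractor builds an undirected `nx.Graph`, where adding an edge for `(a, b)` and then for `(b, a)` updates the same edge, so the second `relations` attribute would replace the first.

```python
from collections import Counter, defaultdict
from typing import Dict, FrozenSet, List, Set, Tuple

# Hypothetical stand-in for Renard's Relation: (subject, relation, object)
RelationTuple = Tuple[str, str, str]


def aggregate_relations(
    sentence_relations: List[List[RelationTuple]], min_rel_occurrences: int = 1
) -> Dict[FrozenSet[str], Set[str]]:
    """Count relations per unordered character pair and keep frequent ones."""
    # { {char1, char2} => { relation: count } }
    edge_relations: Dict[FrozenSet[str], Counter] = defaultdict(Counter)
    for relations in sentence_relations:
        for subj, rel, obj in relations:
            edge_relations[frozenset((subj, obj))][rel] += 1

    edges = {}
    for pair, counter in edge_relations.items():
        kept = {rel for rel, count in counter.items() if count >= min_rel_occurrences}
        # As in the commit, a pair with no relation above the threshold
        # produces no edge at all.
        if kept:
            edges[pair] = kept
    return edges


if __name__ == "__main__":
    sentence_relations = [
        [("Alice", "friend_of", "Bob")],
        [("Bob", "friend_of", "Alice"), ("Alice", "sibling_of", "Carol")],
    ]
    print(aggregate_relations(sentence_relations, min_rel_occurrences=2))
```

Each resulting pair could then be passed to `G.add_edge(*pair, relations=kept)` for two-character pairs. Whether merging both directions is desirable depends on whether the relation types are symmetric, which the diff does not say.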
