11from typing import Dict , Any , List , Set , Optional , Tuple , Literal , Union
22import itertools as it
3+ from collections import defaultdict
34import operator
45
56import networkx as nx
1112from renard .pipeline .core import PipelineStep
1213from renard .pipeline .character_unification import Character
1314from renard .pipeline .quote_detection import Quote
15+ from renard .pipeline .relation_extraction import Relation
1416
1517
1618def sent_index_for_token_index (token_index : int , sentences : List [List [str ]]) -> int :
@@ -147,7 +149,7 @@ def __call__(
147149 sentences : List [List [str ]],
148150 char2token : Optional [List [int ]] = None ,
149151 dynamic_blocks : Optional [BlockBounds ] = None ,
150- sentences_polarities : Optional [List [float ]] = None ,
152+ sentence_polarities : Optional [List [float ]] = None ,
151153 entities : Optional [List [NEREntity ]] = None ,
152154 co_occurrences_blocks : Optional [BlockBounds ] = None ,
153155 ** kwargs ,
@@ -194,13 +196,13 @@ def __call__(
194196 self .dynamic_overlap ,
195197 dynamic_blocks ,
196198 sentences ,
197- sentences_polarities ,
199+ sentence_polarities ,
198200 co_occurrences_blocks ,
199201 )
200202 }
201203 return {
202204 "character_network" : self ._extract_graph (
203- mentions , sentences , sentences_polarities , co_occurrences_blocks
205+ mentions , sentences , sentence_polarities , co_occurrences_blocks
204206 )
205207 }
206208
@@ -257,24 +259,24 @@ def _extract_graph(
257259 self ,
258260 mentions : List [Tuple [Any , NEREntity ]],
259261 sentences : List [List [str ]],
260- sentences_polarities : Optional [List [float ]],
262+ sentence_polarities : Optional [List [float ]],
261263 co_occurrences_blocks : Optional [BlockBounds ],
262264 ) -> nx .Graph :
263265 """
264266 :param mentions: A list of entity mentions, ordered by
265267 appearance, each of the form (KEY MENTION). KEY
266268 determines the unicity of the entity.
267- :param sentences: if specified, ``sentences_polarities `` must
269+ :param sentences: if specified, ``sentence_polarities `` must
268270 be specified as well.
269- :param sentences_polarities : if specified, ``sentences`` must
271+ :param sentence_polarities : if specified, ``sentences`` must
270272 be specified as well. In that case, edges are annotated
271273 with the ``'polarity`` attribute, indicating the polarity
272274 of the relationship between two characters. Polarity
273275 between two interactions is computed as the strongest
274276 sentence polarity between those two mentions.
275277 :param co_occurrences_blocks: only unit 'tokens' is accepted.
276278 """
277- compute_polarity = not sentences_polarities is None
279+ compute_polarity = not sentence_polarities is None
278280
279281 assert co_occurrences_blocks is None or co_occurrences_blocks [1 ] == "tokens"
280282 if co_occurrences_blocks is None :
@@ -324,15 +326,15 @@ def _extract_graph(
324326
325327 if compute_polarity :
326328 assert not sentences is None
327- assert not sentences_polarities is None
329+ assert not sentence_polarities is None
328330 # TODO: optim
329331 first_sent_idx = sent_index_for_token_index (
330332 mention1 .start_idx , sentences
331333 )
332334 last_sent_idx = sent_index_for_token_index (
333335 mention2 .start_idx , sentences
334336 )
335- sents_polarities_between_mentions = sentences_polarities [
337+ sents_polarities_between_mentions = sentence_polarities [
336338 first_sent_idx : last_sent_idx + 1
337339 ]
338340 polarity = max (sents_polarities_between_mentions , key = abs )
@@ -349,7 +351,7 @@ def _extract_dynamic_graph(
349351 overlap : int ,
350352 dynamic_blocks : Optional [BlockBounds ],
351353 sentences : List [List [str ]],
352- sentences_polarities : Optional [List [float ]],
354+ sentence_polarities : Optional [List [float ]],
353355 co_occurrences_blocks : Optional [BlockBounds ],
354356 ) -> List [nx .Graph ]:
355357 """
@@ -367,14 +369,14 @@ def _extract_dynamic_graph(
367369 """
368370 assert co_occurrences_blocks is None or co_occurrences_blocks [1 ] == "tokens"
369371 assert window is None or dynamic_blocks is None
370- compute_polarity = not sentences is None and not sentences_polarities is None
372+ compute_polarity = not sentences is None and not sentence_polarities is None
371373
372374 if not window is None :
373375 return [
374376 self ._extract_graph (
375377 [elt for elt in ct if not elt is None ],
376378 sentences ,
377- sentences_polarities ,
379+ sentence_polarities ,
378380 co_occurrences_blocks ,
379381 )
380382 for ct in windowed (mentions , window , step = window - overlap )
@@ -391,10 +393,10 @@ def _extract_dynamic_graph(
391393 sent_start , sent_end = sent_indices_for_block (dynamic_block , sentences )
392394 block_sentences = sentences [sent_start : sent_end + 1 ]
393395
394- block_sentences_polarities = None
396+ block_sentence_polarities = None
395397 if compute_polarity :
396- assert not sentences_polarities is None
397- block_sentences_polarities = sentences_polarities [
398+ assert not sentence_polarities is None
399+ block_sentence_polarities = sentence_polarities [
398400 sent_start : sent_end + 1
399401 ]
400402
@@ -412,7 +414,7 @@ def _extract_dynamic_graph(
412414 self ._extract_graph (
413415 block_mentions ,
414416 block_sentences ,
415- block_sentences_polarities ,
417+ block_sentence_polarities ,
416418 block_co_occ_bounds ,
417419 )
418420 )
@@ -441,7 +443,7 @@ def production(self) -> Set[str]:
441443 return {"character_network" }
442444
443445 def optional_needs (self ) -> Set [str ]:
444- return {"sentences_polarities " }
446+ return {"sentence_polarities " }
445447
446448
447449class ConversationalGraphExtractor (PipelineStep ):
@@ -588,7 +590,6 @@ def __call__(
588590 characters : Set [Character ],
589591 ** kwargs ,
590592 ) -> Dict [str , Any ]:
591-
592593 if self .graph_type == "conversation" :
593594 G = self ._conversation_extract (sentences , quotes , speakers , characters )
594595 elif self .graph_type == "mention" :
@@ -608,3 +609,52 @@ def needs(self) -> Set[str]:
608609 def production (self ) -> Set [str ]:
609610 """character_network"""
610611 return {"character_network" }
612+
613+
class RelationalGraphExtractor(PipelineStep):
    """A graph extractor using relations between characters.

    Every character becomes a node of an undirected graph.  Two
    characters are linked by an edge when at least one relation
    between them occurs ``min_rel_occurrences`` times or more; the
    edge carries a ``relations`` attribute holding the set of the
    relations that were kept.

    .. note::

        Does not support dynamic networks yet.
    """

    def __init__(self, min_rel_occurrences: int = 1):
        """
        :param min_rel_occurrences: minimum number of occurrences of
            a (subject, relation, object) triple for the relation to
            be kept on the edge between subject and object.
        """
        self.min_rel_occurrences = min_rel_occurrences

    def __call__(
        self,
        characters: list[Character],
        sentence_relations: list[list[Relation]],
        **kwargs,
    ) -> dict[str, Any]:
        """Extract the relational character network.

        :param characters: characters added as nodes of the graph.
        :param sentence_relations: one list of relations per
            sentence.  Each relation is unpacked as a
            ``(subject, relation, object)`` triple.
        :return: a dict with a single ``"character_network"`` key,
            mapped to the extracted :class:`nx.Graph`.
        """
        G = nx.Graph()
        for character in characters:
            G.add_node(character)

        # { (subject, object) => { relation: occurrence count } }
        edge_relations = defaultdict(lambda: defaultdict(int))
        for relations in sentence_relations:
            for subj, rel, obj in relations:
                edge_relations[(subj, obj)][rel] += 1

        for (char1, char2), counter in edge_relations.items():
            kept_relations = {
                rel
                for rel, count in counter.items()
                if count >= self.min_rel_occurrences
            }
            if not kept_relations:
                continue
            if G.has_edge(char1, char2):
                # The graph is undirected, so (A, B) and (B, A) are
                # the same edge.  Calling add_edge again would replace
                # the 'relations' attribute and drop the relations
                # found in the other direction: merge them instead.
                G.edges[char1, char2]["relations"].update(kept_relations)
            else:
                G.add_edge(char1, char2, relations=kept_relations)

        return {"character_network": G}

    def supported_langs(self) -> Literal["any"]:
        """Return ``"any"``: this step declares no language restriction."""
        return "any"

    def needs(self) -> set[str]:
        """Return the names of the pipeline attributes this step needs:
        ``"characters"`` and ``"sentence_relations"``.
        """
        return {"characters", "sentence_relations"}

    def production(self) -> set[str]:
        """Return the name of the pipeline attribute this step
        produces: ``"character_network"``.
        """
        return {"character_network"}
0 commit comments