Skip to content

Commit b2d65c0

Browse files
committed
ConversationalGraphExtractor now supports dynamic networks
1 parent bc38e6f commit b2d65c0

3 files changed

Lines changed: 154 additions & 26 deletions

File tree

docs/pipeline.rst

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,9 @@ When executing the above block of code, the output attribute
263263
>>> out.character_network
264264
[<networkx.classes.graph.Graph object at 0x7fd9e9115900>]
265265

266-
See :class:`.CoOccurrencesGraphExtractor` for more details on the
267-
usage of the ``dynamic`` and ``dynamic_window`` arguments. Note that,
268-
currently, only the co-occurrence graph extractor supports dynamic
269-
networks.
266+
Both :class:`.CoOccurrencesGraphExtractor` and
267+
:class:`.ConversationalGraphExtractor` support dynamic networks. See
268+
their documentation for more details.
270269

271270
Plot and export functions work as one would expect
272271
intuitively. :meth:`.PipelineState.plot_graph` allows one to visualize the
@@ -280,10 +279,9 @@ dynamic graph to the Gephi format.
280279
Custom Segmentation
281280
-------------------
282281

283-
The ``dynamic_window`` parameter of
284-
:class:`.CoOccurencesGraphExtractor` determines the segmentation of
285-
the dynamic networks, in number of interactions. In the example above,
286-
a new graph will be created for each 20 interactions.
282+
The ``dynamic_window`` parameter determines the segmentation of the
283+
dynamic networks, in number of interactions. In the example above, a
284+
new graph will be created for every 20 interactions.
287285

288286
While one can rely on the arguments of the graph extractor of the
289287
pipeline to determine the dynamic window, Renard allows one to specify a

renard/pipeline/graph_extraction.py

Lines changed: 112 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,32 @@ def mentions_for_blocks(
7676
return blocks_mentions
7777

7878

79+
def quotes_for_blocks(
    block_bounds: BlockBounds, quotes: List[Quote], speakers: List[Optional[Character]]
) -> Tuple[List[List[Quote]], List[List[Optional[Character]]]]:
    """Return quotes and their associated speakers for each block.

    A quote is assigned to the first block that fully contains it.
    Quotes that are not fully contained in any block (for example
    because they span a block boundary) are left out, along with
    their speaker.

    :param block_bounds: block bounds, in tokens
    :param quotes: a list of quotes
    :param speakers: the speaker of each quote, aligned with
        ``quotes`` (``None`` when the speaker is unknown)

    :return: a tuple ``(block_quotes, block_speakers)``.  Both lists
             have len ``len(block_bounds[0])``, and
             ``block_speakers[i][j]`` is the speaker of
             ``block_quotes[i][j]``.
    """
    assert block_bounds[1] == "tokens"

    bounds = block_bounds[0]
    block_quotes = [[] for _ in bounds]
    block_speakers = [[] for _ in bounds]

    for quote, speaker in zip(quotes, speakers):
        for block_i, (start_i, end_i) in enumerate(bounds):
            # a quote ending exactly on the end bound of a block still
            # lies entirely inside that block, hence '<=' and not '<'
            if quote.start >= start_i and quote.end <= end_i:
                block_quotes[block_i].append(quote)
                block_speakers[block_i].append(speaker)
                # a quote belongs to at most one block
                break

    return block_quotes, block_speakers
103+
104+
79105
class CoOccurrencesGraphExtractor(PipelineStep):
80106
"""A simple character graph extractor using co-occurrences"""
81107

@@ -449,10 +475,6 @@ def optional_needs(self) -> Set[str]:
449475
class ConversationalGraphExtractor(PipelineStep):
450476
"""A graph extractor using conversation between characters or
451477
mentions.
452-
453-
.. note::
454-
455-
Does not support dynamic networks yet.
456478
"""
457479

458480
def __init__(
@@ -462,6 +484,9 @@ def __init__(
462484
Union[int, Tuple[int, Literal["tokens", "sentences"]]]
463485
] = None,
464486
ignore_self_mention: bool = True,
487+
dynamic: bool = False,
488+
dynamic_window: Optional[int] = None,
489+
dynamic_overlap: int = 0,
465490
):
466491
"""
467492
:param graph_type: either 'conversation' or 'mention'.
@@ -470,11 +495,31 @@ def __init__(
470495
occurring between characters. 'mention' extracts a
471496
directed graph where interactions are character mentions
472497
of one another in quoted speech.
498+
473499
:param conversation_dist: must be supplied if `graph_type` is
474500
'conversation'. The distance between two quotations for
475501
them to be considered as being interacting.
502+
476503
:param ignore_self_mention: if ``True``, self mentions are
477-
ignore for ``graph_type=='mention'``
504+
ignored for ``graph_type=='mention'``
505+
506+
:param dynamic:
507+
508+
- if ``False`` (the default), a static ``nx.Graph`` is
509+
extracted
510+
511+
- if ``True``, several ``nx.Graph`` are extracted. In
512+
that case, ``dynamic_window`` and
513+
``dynamic_overlap`` *can* be specified. If
514+
``dynamic_window`` is not specified, this step is
515+
expecting the text to be cut into 'dynamic blocks',
516+
and a graph will be extracted for each block. In
517+
that case, ``dynamic_blocks`` must be passed to the
518+
pipeline as a ``List[str]`` at runtime.
519+
520+
:param dynamic_window: dynamic window, in number of quotes.
521+
522+
:param dynamic_overlap: overlap, in number of quotes.
478523
"""
479524
self.graph_type = graph_type
480525

@@ -484,6 +529,10 @@ def __init__(
484529

485530
self.ignore_self_mention = ignore_self_mention
486531

532+
self.dynamic = dynamic
533+
self.dynamic_window = dynamic_window
534+
self.dynamic_overlap = dynamic_overlap
535+
487536
super().__init__()
488537

489538
def _quotes_interact(
@@ -564,12 +613,12 @@ def _mention_extract(
564613
if speaker is None:
565614
continue
566615

567-
# TODO: optim
568616
# find characters mentioned in quote and add a directed
569617
# edge speaker => character
570618
for character in characters:
571619
if character == speaker and self.ignore_self_mention:
572620
continue
621+
# TODO: optim
573622
for mention in character.mentions:
574623
if (
575624
mention.start_idx >= quote.start
@@ -582,22 +631,75 @@ def _mention_extract(
582631

583632
return G
584633

585-
def __call__(
634+
def _extract_static(
    self,
    sentences: List[List[str]],
    quotes: List[Quote],
    speakers: List[Optional[Character]],
    characters: Set[Character],
) -> nx.Graph:
    """Extract a single graph, dispatching on ``self.graph_type``.

    :param sentences: forwarded to the 'conversation' extraction
    :param quotes: quotes to extract interactions from
    :param speakers: speaker of each quote
    :param characters: characters to consider

    :return: the graph built by the extraction routine matching
             ``self.graph_type``

    :raises ValueError: if ``self.graph_type`` is neither
        'conversation' nor 'mention'
    """
    kind = self.graph_type

    if kind == "conversation":
        return self._conversation_extract(sentences, quotes, speakers, characters)

    if kind == "mention":
        return self._mention_extract(quotes, speakers, characters)

    raise ValueError(f"unknown graph_type: {self.graph_type}")
599648

600-
return {"character_network": G}
649+
def _extract_dynamic(
    self,
    sentences: List[List[str]],
    quotes: List[Quote],
    speakers: List[Optional[Character]],
    characters: Set[Character],
    dynamic_blocks: Optional[BlockBounds] = None,
) -> List[nx.Graph]:
    """Extract one graph per dynamic segment.

    Segments are either windows of ``self.dynamic_window`` quotes
    (with ``self.dynamic_overlap`` quotes shared between consecutive
    windows), or the supplied ``dynamic_blocks``.  Exactly one of
    the two must be specified.

    :param sentences: forwarded to the static extraction
    :param quotes: quotes to extract interactions from
    :param speakers: speaker of each quote, aligned with ``quotes``
    :param characters: characters to consider
    :param dynamic_blocks: block bounds, in tokens.  Must be
        ``None`` when ``self.dynamic_window`` is set.

    :return: one graph per window or block.  When windowing, no
             graph is produced if there are no quotes.
    """
    assert (
        self.dynamic_window is None or dynamic_blocks is None
    ), "dynamic_window and dynamic_blocks are mutually exclusive"

    if self.dynamic_window is not None:
        # Window (quote, speaker) pairs directly rather than going
        # through token bounds: every quote of a window is kept, and
        # a quote shared by overlapping windows appears in each of
        # them.
        pairs = list(zip(quotes, speakers))
        graphs = []
        for window in windowed(
            pairs,
            self.dynamic_window,
            step=self.dynamic_window - self.dynamic_overlap,
        ):
            # incomplete windows are padded with None: pairs are
            # tuples, so a None entry can only be padding, even when
            # a speaker itself is None
            window = [pair for pair in window if pair is not None]
            if not window:
                continue
            graphs.append(
                self._extract_static(
                    sentences,
                    [quote for quote, _ in window],
                    [speaker for _, speaker in window],
                    characters,
                )
            )
        return graphs

    assert (
        dynamic_blocks is not None
    ), "either dynamic_window or dynamic_blocks must be specified"

    quotes_for_each_block, speakers_for_each_block = quotes_for_blocks(
        dynamic_blocks, quotes, speakers
    )
    return [
        self._extract_static(sentences, block_quotes, block_speakers, characters)
        for block_quotes, block_speakers in zip(
            quotes_for_each_block, speakers_for_each_block
        )
    ]
682+
def __call__(
    self,
    sentences: List[List[str]],
    quotes: List[Quote],
    speakers: List[Optional[Character]],
    characters: Set[Character],
    dynamic_blocks: Optional[BlockBounds] = None,
    **kwargs,
) -> Dict[str, Union[nx.Graph, List[nx.Graph]]]:
    """Run the step and return the extracted network.

    :param sentences: forwarded to the extraction
    :param quotes: quotes to extract interactions from
    :param speakers: speaker of each quote
    :param characters: characters to consider
    :param dynamic_blocks: only used when ``self.dynamic`` is
        ``True``
    :param kwargs: accepted and ignored

    :return: a dict with a single ``"character_network"`` key,
             holding a list of graphs when ``self.dynamic`` is
             ``True`` and a single graph otherwise.
    """
    if self.dynamic:
        network = self._extract_dynamic(
            sentences, quotes, speakers, characters, dynamic_blocks
        )
    else:
        network = self._extract_static(sentences, quotes, speakers, characters)

    return {"character_network": network}
601703

602704
def supported_langs(self) -> Literal["any"]:
603705
return "any"

tests/test_graph_extraction.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from collections import defaultdict
2-
from typing import List
2+
from typing import List, Literal
33
import itertools, string
44
from hypothesis import given
5-
from hypothesis.strategies import lists, sampled_from
5+
from hypothesis.strategies import lists, sampled_from, one_of, just
66
from hypothesis.strategies._internal.numbers import integers
77
import networkx as nx
88
from networkx.algorithms import isomorphism
9-
from renard.pipeline.graph_extraction import CoOccurrencesGraphExtractor
9+
from renard.pipeline.graph_extraction import (
10+
CoOccurrencesGraphExtractor,
11+
ConversationalGraphExtractor,
12+
)
1013
from renard.pipeline.character_unification import Character
14+
from renard.pipeline.speaker_attribution import Quote
1115
from renard.pipeline.ner import ner_entities, NEREntity
1216

1317

@@ -67,11 +71,6 @@ def test_basic_graph_extraction(tokens: List[str]):
6771
def test_dynamic_co_occurrences_graph_extraction(
6872
tokens: List[str], dynamic_window: int
6973
):
70-
"""
71-
.. note::
72-
73-
only tests execution.
74-
"""
7574
bio_tags = ["B-PER" for _ in tokens]
7675

7776
mentions = ner_entities(tokens, bio_tags)
@@ -82,6 +81,35 @@ def test_dynamic_co_occurrences_graph_extraction(
8281
)
8382
out = graph_extractor(set(characters), [tokens])
8483

84+
assert isinstance(out["character_network"], list)
85+
assert len(out["character_network"]) > 0
86+
87+
88+
@given(
    one_of(just("conversation"), just("mention")),
    lists(sampled_from(string.ascii_uppercase), min_size=1),
    integers(min_value=1, max_value=5),
)
def test_dynamic_conversational_graph_extraction(
    graph_type: Literal["conversation", "mention"],
    tokens: List[str],
    dynamic_window: int,
):
    """Dynamic conversational extraction yields a non-empty list.

    Every token is tagged as a person and wrapped in its own
    one-token quote, all attributed to the first character.  Only
    the type and non-emptiness of the output are checked.
    """
    bio_tags = ["B-PER"] * len(tokens)
    characters = _characters_from_mentions(ner_entities(tokens, bio_tags))

    # one single-token quote per token, all spoken by the same character
    quotes = [Quote(i, i + 1, [token]) for i, token in enumerate(tokens)]
    speakers = [characters[0]] * len(quotes)

    extractor = ConversationalGraphExtractor(
        graph_type,
        ignore_self_mention=False,
        dynamic=True,
        dynamic_window=dynamic_window,
    )
    networks = extractor([tokens], quotes, speakers, set(characters))[
        "character_network"
    ]

    assert isinstance(networks, list)
    assert len(networks) > 0
86114

87115

0 commit comments

Comments
 (0)