Skip to content

Commit 658e60d

Browse files
committed
New id-keyed API for set_node_styles and set_mutation_styles
1 parent cadccad commit 658e60d

1 file changed

Lines changed: 39 additions & 10 deletions

File tree

tskit_arg_visualizer/__init__.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -635,22 +635,51 @@ def set_all_node_styles(self, size=None, symbol=None, fill=None, stroke=None, st
635635
self.nodes["stroke"] = stroke
636636
if stroke_width != None:
637637
self.nodes["stroke_width"] = stroke_width
638-
638+
639+
@staticmethod
640+
def _set_styles(dataframe, styles, allowed_keys):
641+
# NB: this is very inefficient for lots of nodes. If you want to style e.g.
642+
# a set of nodes, it is quicker to use something like
643+
# d3arg.nodes.loc[np.isin(self.nodes["id"], node_ids), "size"] = size
644+
for item_id, style in styles.items():
645+
if 'id' in dataframe.columns:
646+
use = dataframe["id"] == item_id
647+
else:
648+
use = item_id
649+
for k in style.keys():
650+
if k not in allowed_keys:
651+
raise ValueError(
652+
f"Invalid key '{k}' in styles. Allowed keys are {allowed_keys}.")
653+
dataframe.loc[use, list(style.keys())] = list(style.values())
654+
639655
def set_node_styles(self, styles):
640656
"""Individually control the styling of each node.
641657
642658
Parameters
643659
----------
644-
styles : list
645-
List of dicts, one per node, with the styling keys: id, size, symbol, fill, stroke, stroke_width.
646-
"id" is the only mandatory key. Only nodes that need styles updated need to be provided.
660+
styles : dict
661+
A dictionary whose keys are node ids, and whose values are separate dicts
662+
containing any of the styling keys: size, symbol, fill, stroke, stroke_width.
663+
Only nodes that need styles updated need to be provided, and all styling
664+
keys are optional.
647665
"""
666+
allowed_keys = {"size", "symbol", "fill", "stroke", "stroke_width"}
667+
self._set_styles(self.nodes, styles, allowed_keys)
668+
669+
def set_mutation_styles(self, styles):
670+
"""Individually control the styling of each mutation.
671+
672+
Parameters
673+
----------
674+
styles : dict
675+
A dictionary whose keys are mutation ids, and whose values are separate
676+
dictionaries containing any of the styling keys: size, fill, stroke.
677+
Only mutations that need styles updated need to be provided, and all
678+
styling keys are optional.
679+
"""
680+
allowed_keys = {"fill", "stroke", "size"}
681+
self._set_styles(self.mutations, styles, allowed_keys)
648682

649-
for node in styles:
650-
for key in node.keys():
651-
if key in ["size", "symbol", "fill", "stroke", "stroke_width"]:
652-
self.nodes.loc[self.nodes["id"]==node["id"], key] = node[key]
653-
654683
def set_edge_colors(self, colors):
655684
"""Set the color of each edge in the ARG
656685
@@ -660,7 +689,7 @@ def set_edge_colors(self, colors):
660689
ID of the edge and its new color
661690
"""
662691

663-
for id in colors:
692+
for id, val in colors:
664693
if id in self.edges["id"]:
665694
self.edges.loc[self.edges["id"]==id, "stroke"] = colors[id]
666695
else:

0 commit comments

Comments
 (0)