@@ -635,22 +635,51 @@ def set_all_node_styles(self, size=None, symbol=None, fill=None, stroke=None, st
635635 self .nodes ["stroke" ] = stroke
636636 if stroke_width != None :
637637 self .nodes ["stroke_width" ] = stroke_width
638-
638+
639+ @staticmethod
640+ def _set_styles (dataframe , styles , allowed_keys ):
641+ # NB: this is very inefficient for lots of nodes. If you want to style e.g.
642+ # a set of nodes, it is quicker to use something like
643+ # d3arg.nodes.loc[np.isin(self.nodes["id"], node_ids), "size"] = size
644+ for item_id , style in styles .items ():
645+ if 'id' in dataframe .columns :
646+ use = dataframe ["id" ] == item_id
647+ else :
648+ use = item_id
649+ for k in style .keys ():
650+ if k not in allowed_keys :
651+ raise ValueError (
652+ f"Invalid key '{ k } ' in styles. Allowed keys are { allowed_keys } ." )
653+ dataframe .loc [use , list (style .keys ())] = list (style .values ())
654+
639655 def set_node_styles (self , styles ):
640656 """Individually control the styling of each node.
641657
642658 Parameters
643659 ----------
644- styles : list
645- List of dicts, one per node, with the styling keys: id, size, symbol, fill, stroke, stroke_width.
646- "id" is the only mandatory key. Only nodes that need styles updated need to be provided.
660+ styles : dict
661+ A dictionary whose keys are node ids, and whose values are separate dicts
662+ containing any of the styling keys: size, symbol, fill, stroke, stroke_width.
663+ Only nodes that need styles updated need to be provided, and all styling
664+ keys are optional.
647665 """
666+ allowed_keys = {"size" , "symbol" , "fill" , "stroke" , "stroke_width" }
667+ self ._set_styles (self .nodes , styles , allowed_keys )
668+
669+ def set_mutation_styles (self , styles ):
670+ """Individually control the styling of each mutation.
671+
672+ Parameters
673+ ----------
674+ styles : dict
675+ A dictionary whose keys are mutation ids, and whose values are separate
676+ dictionaries containing any of the styling keys: size, fill, stroke.
677+ Only mutations that need styles updated need to be provided, and all
678+ styling keys are optional.
679+ """
680+ allowed_keys = {"fill" , "stroke" , "size" }
681+ self ._set_styles (self .mutations , styles , allowed_keys )
648682
649- for node in styles :
650- for key in node .keys ():
651- if key in ["size" , "symbol" , "fill" , "stroke" , "stroke_width" ]:
652- self .nodes .loc [self .nodes ["id" ]== node ["id" ], key ] = node [key ]
653-
654683 def set_edge_colors (self , colors ):
655684 """Set the color of each edge in the ARG
656685
@@ -660,7 +689,7 @@ def set_edge_colors(self, colors):
660689 ID of the edge and its new color
661690 """
662691
663- for id in colors :
692+ for id , val in colors :
664693 if id in self .edges ["id" ]:
665694 self .edges .loc [self .edges ["id" ]== id , "stroke" ] = colors [id ]
666695 else :
0 commit comments