Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions python/tskit/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def tostring(self):


class Drawing:
def __init__(self, size=None, debug=False, preamble=None, **kwargs):
def __init__(self, size=None, **kwargs):
kwargs = {
"version": "1.1",
"xmlns": "http://www.w3.org/2000/svg",
Expand All @@ -148,8 +148,7 @@ def __init__(self, size=None, debug=False, preamble=None, **kwargs):
kwargs["height"] = size[1]

self.root = Element("svg", **kwargs)
if preamble is not None:
self.root.add(preamble)
self.root.add("") # First root elem is a blank preamble
self.defs = Element("defs")
self.root.add(self.defs)

Expand Down Expand Up @@ -860,6 +859,8 @@ class SvgPlot:

text_height = 14 # May want to calculate this based on a font size
line_height = text_height * 1.2 # allowing padding above and below a line
default_width = 200 # for a single tree
default_height = 200

def __init__(
self,
Expand All @@ -879,10 +880,9 @@ def __init__(
root_svg_attributes = {}
if canvas_size is None:
canvas_size = size
dwg = Drawing(
size=canvas_size, debug=True, preamble=preamble, **root_svg_attributes
)
dwg = Drawing(size=canvas_size, **root_svg_attributes)

self.preamble = preamble
self.image_size = size
self.plotbox = Plotbox(size)
self.root_groups = {}
Expand All @@ -892,6 +892,15 @@ def __init__(
self.dwg_base = dwg.add(dwg.g(class_=svg_class))
self.drawing = dwg

def draw(self, path=None):
if self.preamble is not None:
self.drawing.root.children[0] = self.preamble
output = self.drawing.tostring()
if path is not None:
# TODO remove the 'pretty' when we are done debugging this.
self.drawing.saveas(path, pretty=True)
return SVGString(output)

def get_plotbox(self):
"""
Get the svgwrite plotbox, creating it if necessary.
Expand Down Expand Up @@ -1361,7 +1370,7 @@ def __init__(
use_skipped = np.append(np.diff(self.tree_status & OMIT_MIDDLE == 0) == 1, 0)
num_plotboxes = np.sum(np.logical_or(use_tree, use_skipped))
if size is None:
size = (200 * int(num_plotboxes), 200)
size = (self.default_width * int(num_plotboxes), self.default_height)
if max_time is None:
max_time = "ts"
if min_time is None:
Expand Down Expand Up @@ -1674,7 +1683,7 @@ def __init__(
stacklevel=4,
)
if size is None:
size = (200, 200)
size = (self.default_width, self.default_height)
if symbol_size is None:
symbol_size = 6
self.symbol_size = symbol_size
Expand Down
16 changes: 4 additions & 12 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,7 +1964,7 @@ def draw_svg(
:return: An SVG representation of a tree.
:rtype: SVGString
"""
draw = drawing.SvgTree(
svgtree = drawing.SvgTree(
self,
size,
time_scale=time_scale,
Expand Down Expand Up @@ -1995,11 +1995,7 @@ def draw_svg(
preamble=preamble,
**kwargs,
)
output = draw.drawing.tostring()
if path is not None:
# TODO: removed the pretty here when this is stable.
draw.drawing.saveas(path, pretty=True)
return drawing.SVGString(output)
return svgtree.draw(path)

def draw(
self,
Expand Down Expand Up @@ -7599,7 +7595,7 @@ def draw_svg(
strictly within an empty region then that tree will not be plotted on the
right hand side, and the X axis will end at ``empty_tree.interval.left``
"""
draw = drawing.SvgTreeSequence(
svgtreesequence = drawing.SvgTreeSequence(
self,
size,
x_scale=x_scale,
Expand Down Expand Up @@ -7629,11 +7625,7 @@ def draw_svg(
preamble=preamble,
**kwargs,
)
output = draw.drawing.tostring()
if path is not None:
# TODO remove the 'pretty' when we are done debugging this.
draw.drawing.saveas(path, pretty=True)
return drawing.SVGString(output)
return svgtreesequence.draw(path)

def draw_text(
self,
Expand Down
Loading