Skip to content

Commit 92d7b2e

Browse files
authored
Merge pull request #32 from max-models/bugfixes
Bugfixes and code formatting
2 parents fffec01 + c77fc25 commit 92d7b2e

6 files changed

Lines changed: 183 additions & 179 deletions

File tree

src/maxplotlib/canvas/canvas.py

Lines changed: 146 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,141 @@
1818
from maxplotlib.utils.options import Backends
1919

2020

21+
def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None):
22+
"""
23+
Plot all nodes and paths on the provided axis using Matplotlib.
24+
25+
Parameters:
26+
- ax (matplotlib.axes.Axes): Axis on which to plot the figure.
27+
"""
28+
29+
# TODO: Specify which layers to retreive nodes from with layers=layers
30+
nodes = tikzfigure.layers.get_nodes()
31+
paths = tikzfigure.layers.get_paths()
32+
33+
for path in paths:
34+
x_coords = [node.x for node in path.nodes]
35+
y_coords = [node.y for node in path.nodes]
36+
37+
# Parse path color
38+
path_color_spec = path.kwargs.get("color", "black")
39+
try:
40+
color = Color(path_color_spec).to_rgb()
41+
except ValueError as e:
42+
print(e)
43+
color = "black"
44+
45+
# Parse line width
46+
line_width_spec = path.kwargs.get("line_width", 1)
47+
if isinstance(line_width_spec, str):
48+
match = re.match(r"([\d.]+)(pt)?", line_width_spec)
49+
if match:
50+
line_width = float(match.group(1))
51+
else:
52+
print(
53+
f"Invalid line width specification: '{line_width_spec}', defaulting to 1",
54+
)
55+
line_width = 1
56+
else:
57+
line_width = float(line_width_spec)
58+
59+
# Parse line style using Linestyle class
60+
style_spec = path.kwargs.get("style", "solid")
61+
linestyle = Linestyle(style_spec).to_matplotlib()
62+
63+
ax.plot(
64+
x_coords,
65+
y_coords,
66+
color=color,
67+
linewidth=line_width,
68+
linestyle=linestyle,
69+
zorder=1, # Lower z-order to place behind nodes
70+
)
71+
72+
# Plot nodes after paths so they appear on top
73+
for node in nodes:
74+
# Determine shape and size
75+
shape = node.kwargs.get("shape", "circle")
76+
fill_color_spec = node.kwargs.get("fill", "white")
77+
edge_color_spec = node.kwargs.get("draw", "black")
78+
linewidth = float(node.kwargs.get("line_width", 1))
79+
size = float(node.kwargs.get("size", 1))
80+
81+
# Parse colors using the Color class
82+
try:
83+
facecolor = Color(fill_color_spec).to_rgb()
84+
except ValueError as e:
85+
print(e)
86+
facecolor = "white"
87+
88+
try:
89+
edgecolor = Color(edge_color_spec).to_rgb()
90+
except ValueError as e:
91+
print(e)
92+
edgecolor = "black"
93+
94+
# Plot shapes
95+
if shape == "circle":
96+
radius = size / 2
97+
circle = patches.Circle(
98+
(node.x, node.y),
99+
radius,
100+
facecolor=facecolor,
101+
edgecolor=edgecolor,
102+
linewidth=linewidth,
103+
zorder=2, # Higher z-order to place on top of paths
104+
)
105+
ax.add_patch(circle)
106+
elif shape == "rectangle":
107+
width = height = size
108+
rect = patches.Rectangle(
109+
(node.x - width / 2, node.y - height / 2),
110+
width,
111+
height,
112+
facecolor=facecolor,
113+
edgecolor=edgecolor,
114+
linewidth=linewidth,
115+
zorder=2, # Higher z-order
116+
)
117+
ax.add_patch(rect)
118+
else:
119+
# Default to circle if shape is unknown
120+
radius = size / 2
121+
circle = patches.Circle(
122+
(node.x, node.y),
123+
radius,
124+
facecolor=facecolor,
125+
edgecolor=edgecolor,
126+
linewidth=linewidth,
127+
zorder=2,
128+
)
129+
ax.add_patch(circle)
130+
131+
# Add text inside the shape
132+
if node.content:
133+
ax.text(
134+
node.x,
135+
node.y,
136+
node.content,
137+
fontsize=10,
138+
ha="center",
139+
va="center",
140+
wrap=True,
141+
zorder=3, # Even higher z-order for text
142+
)
143+
144+
# Remove axes, ticks, and legend
145+
ax.axis("off")
146+
147+
# Adjust plot limits
148+
all_x = [node.x for node in nodes]
149+
all_y = [node.y for node in nodes]
150+
padding = 1 # Adjust padding as needed
151+
ax.set_xlim(min(all_x) - padding, max(all_x) + padding)
152+
ax.set_ylim(min(all_y) - padding, max(all_y) + padding)
153+
ax.set_aspect("equal", adjustable="datalim")
154+
155+
21156
class Canvas:
22157
def __init__(
23158
self,
@@ -29,7 +164,7 @@ def __init__(
29164
label: str | None = None,
30165
fontsize: int = 14,
31166
dpi: int = 300,
32-
width: str = "17cm",
167+
width: str = "5cm",
33168
ratio: str = "golden", # TODO Add literal
34169
gridspec_kw: Dict = {"wspace": 0.08, "hspace": 0.1},
35170
):
@@ -62,6 +197,8 @@ def __init__(
62197
self._ratio = ratio
63198
self._gridspec_kw = gridspec_kw
64199
self._plotted = False
200+
self._matplotlib_fig = None
201+
self._matplotlib_axes = None
65202

66203
# Dictionary to store lines for each subplot
67204
# Key: (row, col), Value: list of lines with their data and kwargs
@@ -106,7 +243,6 @@ def add_line(
106243
subplot: LinePlot | None = None,
107244
row: int | None = None,
108245
col: int | None = None,
109-
plot_type="plot",
110246
**kwargs,
111247
):
112248
if row is not None and col is not None:
@@ -126,7 +262,6 @@ def add_line(
126262
x_data=x_data,
127263
y_data=y_data,
128264
layer=layer,
129-
plot_type=plot_type,
130265
**kwargs,
131266
)
132267

@@ -304,7 +439,7 @@ def show(
304439
elif backend == "plotly":
305440
self.plot_plotly(savefig=False)
306441
elif backend == "tikzpics":
307-
fig = self.plot_tikzpics(savefig=False)
442+
fig = self.plot_tikzpics(savefig=False, verbose=verbose)
308443
fig.show()
309444
else:
310445
raise ValueError("Invalid backend")
@@ -374,8 +509,8 @@ def plot_matplotlib(
374509

375510
def plot_tikzpics(
376511
self,
377-
savefig=None,
378-
verbose=False,
512+
savefig: str | None = None,
513+
verbose: bool = False,
379514
) -> TikzFigure:
380515
if len(self.subplots) > 1:
381516
raise NotImplementedError(
@@ -507,13 +642,6 @@ def label(self, value):
507642
def figsize(self, value):
508643
self._figsize = value
509644

510-
# Magic methods
511-
def __str__(self):
512-
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})"
513-
514-
def __repr__(self):
515-
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})"
516-
517645
def __getitem__(self, key):
518646
"""Allows accessing subplots by tuple index."""
519647
row, col = key
@@ -528,140 +656,12 @@ def __setitem__(self, key, value):
528656
raise IndexError("Subplot index out of range")
529657
self._subplot_matrix[row][col] = value
530658

659+
def __repr__(self):
660+
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, caption={self.caption}, label={self.label})"
531661

532-
def plot_matplotlib(tikzfigure: TikzFigure, ax, layers=None):
533-
"""
534-
Plot all nodes and paths on the provided axis using Matplotlib.
535-
536-
Parameters:
537-
- ax (matplotlib.axes.Axes): Axis on which to plot the figure.
538-
"""
539-
540-
# TODO: Specify which layers to retreive nodes from with layers=layers
541-
nodes = tikzfigure.layers.get_nodes()
542-
paths = tikzfigure.layers.get_paths()
543-
544-
for path in paths:
545-
x_coords = [node.x for node in path.nodes]
546-
y_coords = [node.y for node in path.nodes]
547-
548-
# Parse path color
549-
path_color_spec = path.kwargs.get("color", "black")
550-
try:
551-
color = Color(path_color_spec).to_rgb()
552-
except ValueError as e:
553-
print(e)
554-
color = "black"
555-
556-
# Parse line width
557-
line_width_spec = path.kwargs.get("line_width", 1)
558-
if isinstance(line_width_spec, str):
559-
match = re.match(r"([\d.]+)(pt)?", line_width_spec)
560-
if match:
561-
line_width = float(match.group(1))
562-
else:
563-
print(
564-
f"Invalid line width specification: '{line_width_spec}', defaulting to 1",
565-
)
566-
line_width = 1
567-
else:
568-
line_width = float(line_width_spec)
569-
570-
# Parse line style using Linestyle class
571-
style_spec = path.kwargs.get("style", "solid")
572-
linestyle = Linestyle(style_spec).to_matplotlib()
573-
574-
ax.plot(
575-
x_coords,
576-
y_coords,
577-
color=color,
578-
linewidth=line_width,
579-
linestyle=linestyle,
580-
zorder=1, # Lower z-order to place behind nodes
581-
)
582-
583-
# Plot nodes after paths so they appear on top
584-
for node in nodes:
585-
# Determine shape and size
586-
shape = node.kwargs.get("shape", "circle")
587-
fill_color_spec = node.kwargs.get("fill", "white")
588-
edge_color_spec = node.kwargs.get("draw", "black")
589-
linewidth = float(node.kwargs.get("line_width", 1))
590-
size = float(node.kwargs.get("size", 1))
591-
592-
# Parse colors using the Color class
593-
try:
594-
facecolor = Color(fill_color_spec).to_rgb()
595-
except ValueError as e:
596-
print(e)
597-
facecolor = "white"
598-
599-
try:
600-
edgecolor = Color(edge_color_spec).to_rgb()
601-
except ValueError as e:
602-
print(e)
603-
edgecolor = "black"
604-
605-
# Plot shapes
606-
if shape == "circle":
607-
radius = size / 2
608-
circle = patches.Circle(
609-
(node.x, node.y),
610-
radius,
611-
facecolor=facecolor,
612-
edgecolor=edgecolor,
613-
linewidth=linewidth,
614-
zorder=2, # Higher z-order to place on top of paths
615-
)
616-
ax.add_patch(circle)
617-
elif shape == "rectangle":
618-
width = height = size
619-
rect = patches.Rectangle(
620-
(node.x - width / 2, node.y - height / 2),
621-
width,
622-
height,
623-
facecolor=facecolor,
624-
edgecolor=edgecolor,
625-
linewidth=linewidth,
626-
zorder=2, # Higher z-order
627-
)
628-
ax.add_patch(rect)
629-
else:
630-
# Default to circle if shape is unknown
631-
radius = size / 2
632-
circle = patches.Circle(
633-
(node.x, node.y),
634-
radius,
635-
facecolor=facecolor,
636-
edgecolor=edgecolor,
637-
linewidth=linewidth,
638-
zorder=2,
639-
)
640-
ax.add_patch(circle)
641-
642-
# Add text inside the shape
643-
if node.content:
644-
ax.text(
645-
node.x,
646-
node.y,
647-
node.content,
648-
fontsize=10,
649-
ha="center",
650-
va="center",
651-
wrap=True,
652-
zorder=3, # Even higher z-order for text
653-
)
654-
655-
# Remove axes, ticks, and legend
656-
ax.axis("off")
657-
658-
# Adjust plot limits
659-
all_x = [node.x for node in nodes]
660-
all_y = [node.y for node in nodes]
661-
padding = 1 # Adjust padding as needed
662-
ax.set_xlim(min(all_x) - padding, max(all_x) + padding)
663-
ax.set_ylim(min(all_y) - padding, max(all_y) + padding)
664-
ax.set_aspect("equal", adjustable="datalim")
662+
# Magic methods
663+
def __str__(self):
664+
return f"Canvas(nrows={self.nrows}, ncols={self.ncols}, figsize={self.figsize})"
665665

666666

667667
if __name__ == "__main__":

src/maxplotlib/colors/colors.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,6 @@
55

66

77
class Color:
8-
def __init__(self, color_spec):
9-
"""
10-
Initialize the Color object by parsing the color specification.
11-
12-
Parameters:
13-
- color_spec: Can be a TikZ color string (e.g., 'blue!20'), a standard color name,
14-
an RGB tuple, a hex code, etc.
15-
"""
16-
self.color_spec = color_spec
17-
self.rgb = self._parse_color(color_spec)
188

199
def _parse_color(self, color_spec):
2010
"""
@@ -53,6 +43,17 @@ def _parse_color(self, color_spec):
5343
except ValueError:
5444
raise ValueError(f"Invalid color specification: '{color_spec}'")
5545

46+
def __init__(self, color_spec):
47+
"""
48+
Initialize the Color object by parsing the color specification.
49+
50+
Parameters:
51+
- color_spec: Can be a TikZ color string (e.g., 'blue!20'), a standard color name,
52+
an RGB tuple, a hex code, etc.
53+
"""
54+
self.color_spec = color_spec
55+
self.rgb = self._parse_color(color_spec)
56+
5657
def to_rgb(self):
5758
"""
5859
Return the color as an RGB tuple.

0 commit comments

Comments
 (0)