Skip to content

Commit 795b090

Browse files
committed
Update graph building
1 parent 9ba5c3a commit 795b090

9 files changed

Lines changed: 1832 additions & 9 deletions

File tree

arc/plotter.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from arc.species.perceive import perceive_molecule_from_xyz
5353
from arc.species.species import ARCSpecies, rmg_mol_to_dict_repr
54+
from arc.provenance.nodes import _enum_val, NodeType, EdgeType, DecisionKind
5455

5556

5657
PRETTY_UNITS = {'(s^-1)': r' (s$^-1$)',
@@ -73,15 +74,132 @@ def _wrap_graph_label(text: str, width: int = 24) -> str:
7374
for line in (textwrap.wrap(part, width=width) or ['']))
7475

7576

77+
def render_provenance_graph(prov_graph, run_label: str = 'ARC run') -> 'graphviz.Digraph':
78+
"""
79+
Render a :class:`ProvenanceGraph` as a Graphviz directed graph.
80+
81+
Node styling by type:
82+
- **species**: box / aliceblue
83+
- **calculation**: box / color by status (honeydew=done, mistyrose=errored, white=pending)
84+
- **data**: note / cornsilk
85+
- **decision**: diamond / color by kind (lavender, moccasin, mistyrose)
86+
87+
Edge styling by type:
88+
- ``selected_by``: solid green
89+
- ``rejected_by``: dashed red
90+
- ``troubleshot_by``: dashed orange
91+
- ``retried_as`` / ``fine_of``: dotted gray
92+
- others: solid black
93+
94+
Args:
95+
prov_graph: A :class:`ProvenanceGraph` instance.
96+
run_label (str): Label for the root run node.
97+
98+
Returns:
99+
graphviz.Digraph: The rendered graph object.
100+
"""
101+
gv = graphviz.Digraph(
102+
name='arc_provenance',
103+
comment=f'ARC provenance for {run_label}',
104+
graph_attr={'rankdir': 'LR', 'splines': 'true', 'overlap': 'false'},
105+
node_attr={'shape': 'box', 'style': 'rounded,filled', 'fillcolor': 'white', 'fontname': 'Helvetica'},
106+
edge_attr={'fontname': 'Helvetica'},
107+
)
108+
109+
# Node styling lookup
110+
_calc_colors = {'done': 'honeydew', 'errored': 'mistyrose', 'pending': 'white'}
111+
_decision_colors = {
112+
'ts_guess_selection': 'lavender',
113+
'ts_guess_selection_failed': 'mistyrose',
114+
'job_troubleshooting': 'moccasin',
115+
'conformer_selection': 'lavender',
116+
'ts_guess_clustering': 'lavender',
117+
'ts_method_spawning': 'lavender',
118+
'ts_validation_freq': 'lightyellow',
119+
'ts_validation_nmd': 'lightyellow',
120+
'ts_validation_irc': 'lightyellow',
121+
'ts_switch': 'mistyrose',
122+
}
123+
124+
# Edge styling lookup
125+
_edge_styles = {
126+
'selected_by': {'color': 'green3', 'style': 'solid'},
127+
'rejected_by': {'color': 'red', 'style': 'dashed'},
128+
'troubleshot_by': {'color': 'orange', 'style': 'dashed'},
129+
'triggered_by': {'color': 'gray40', 'style': 'solid'},
130+
'retried_as': {'color': 'gray60', 'style': 'dotted'},
131+
'fine_of': {'color': 'gray60', 'style': 'dotted'},
132+
'spawned_by': {'color': 'blue', 'style': 'solid'},
133+
}
134+
135+
for node in prov_graph.nodes.values():
136+
nid = _sanitize_graphviz_id(node.node_id)
137+
ntype = node.node_type
138+
139+
if ntype == 'species':
140+
lbl = node.label or node.node_id
141+
is_ts = (node.metadata or {}).get('is_ts', False)
142+
if is_ts:
143+
lbl += '\nTS'
144+
gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor='aliceblue')
145+
146+
elif ntype == 'calculation':
147+
parts = [getattr(node, 'job_type', '') or '', getattr(node, 'job_name', '') or '']
148+
if getattr(node, 'job_adapter', None):
149+
parts.append(node.job_adapter)
150+
if getattr(node, 'level', None):
151+
parts.append(node.level)
152+
lbl = '\n'.join(p for p in parts if p)
153+
status = getattr(node, 'status', 'pending') or 'pending'
154+
fillcolor = _calc_colors.get(status, 'white')
155+
gv.node(nid, _wrap_graph_label(lbl), shape='box', fillcolor=fillcolor)
156+
157+
elif ntype == 'data':
158+
dk = getattr(node, 'data_kind', '') or ''
159+
val = getattr(node, 'value', None)
160+
lbl = dk
161+
if val is not None and not isinstance(val, (list, dict)):
162+
lbl += f'\n{val}'
163+
gv.node(nid, _wrap_graph_label(lbl), shape='note', fillcolor='cornsilk')
164+
165+
elif ntype == 'decision':
166+
dk = getattr(node, 'decision_kind', '') or ''
167+
outcome = getattr(node, 'outcome', '') or ''
168+
lbl = dk.replace('_', ' ')
169+
if outcome:
170+
lbl += f'\n{outcome}'
171+
fillcolor = _decision_colors.get(dk, 'lavender')
172+
gv.node(nid, _wrap_graph_label(lbl, width=28), shape='diamond', fillcolor=fillcolor)
173+
174+
else:
175+
gv.node(nid, _wrap_graph_label(node.node_id))
176+
177+
for edge in prov_graph.edges:
178+
src = _sanitize_graphviz_id(edge.source_id)
179+
tgt = _sanitize_graphviz_id(edge.target_id)
180+
etype = edge.edge_type
181+
style_attrs = _edge_styles.get(etype, {})
182+
label = etype.replace('_', ' ') if etype not in ('belongs_to', 'input_of', 'output_of') else ''
183+
gv.edge(src, tgt, label=label, **style_attrs)
184+
185+
return gv
186+
187+
76188
def save_provenance_artifacts(project_directory: str,
77189
provenance: dict,
190+
graph=None,
78191
) -> dict:
79192
"""
80193
Save provenance YAML and render Graphviz artifacts for an ARC run.
81194
195+
When a ``graph`` (:class:`ProvenanceGraph`) is provided, the Graphviz
196+
visualization is built from the graph's typed nodes and edges rather
197+
than the flat event list, producing richer diagrams.
198+
82199
Args:
83200
project_directory (str): The ARC project directory.
84201
provenance (dict): A provenance dictionary with an ``events`` list.
202+
graph: Optional ProvenanceGraph instance for graph-based rendering.
85203
86204
Returns:
87205
dict: Paths to generated artifacts.
@@ -99,6 +217,23 @@ def save_provenance_artifacts(project_directory: str,
99217
save_yaml_file(path=yml_path, content=provenance)
100218
return {'yml': yml_path, 'dot': None, 'svg': None}
101219

220+
# Prefer graph-based rendering when a ProvenanceGraph is available.
221+
if graph is not None and len(graph) > 0:
222+
gv_graph = render_provenance_graph(graph, run_label=run_label)
223+
with open(dot_path, 'w') as f:
224+
f.write(gv_graph.source)
225+
try:
226+
svg_data = gv_graph.pipe(format='svg')
227+
except (graphviz.ExecutableNotFound, graphviz.CalledProcessError):
228+
logger.warning('Could not render ARC provenance SVG because Graphviz is not available on this system.')
229+
else:
230+
with open(svg_path, 'wb') as f:
231+
f.write(svg_data)
232+
provenance['updated_at'] = datetime.datetime.now().isoformat(timespec='seconds')
233+
save_yaml_file(path=yml_path, content=provenance)
234+
return {'yml': yml_path, 'dot': dot_path, 'svg': svg_path if os.path.isfile(svg_path) else None}
235+
236+
# Fallback: event-based rendering (legacy path).
102237
graph = graphviz.Digraph(
103238
name='arc_provenance',
104239
comment=f'ARC provenance for {run_label}',

arc/plotter_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
import shutil
1010
import unittest
1111

12+
try:
13+
import graphviz
14+
except ImportError:
15+
graphviz = None
16+
1217
import arc.plotter as plotter
1318
from arc.common import ARC_PATH, ARC_TESTING_PATH, read_yaml_file, safe_copy_file
1419
from arc.species.converter import str_to_xyz
@@ -300,6 +305,57 @@ def test_save_provenance_artifacts(self):
300305
# Troubleshoot follow-up connects from the decision diamond, not the species node.
301306
self.assertIn('decision_7 -> job_spc1_freq_a3', dot)
302307

308+
def test_render_provenance_graph(self):
309+
"""Test Graphviz rendering from a ProvenanceGraph object."""
310+
from arc.provenance import (ProvenanceGraph, DecisionKind, DataKind, EdgeType)
311+
g = ProvenanceGraph(project='render_test')
312+
sid = g.add_species_node(label='ethanol')
313+
cid = g.add_calculation_node(label='ethanol', job_name='opt_a1',
314+
job_type='opt', job_adapter='gaussian',
315+
level='b3lyp/6-31g(d)', status='done')
316+
did = g.add_data_node(label='ethanol', data_kind=DataKind.energy, value=-79.5)
317+
dec = g.add_decision_node(label='ethanol',
318+
decision_kind=DecisionKind.conformer_selection,
319+
outcome='Selected conformer #0')
320+
g.add_edge(sid, cid, EdgeType.input_of)
321+
g.add_edge(cid, did, EdgeType.output_of)
322+
g.add_edge(did, dec, EdgeType.selected_by)
323+
324+
if graphviz is not None:
325+
gv = plotter.render_provenance_graph(g, run_label='render_test')
326+
dot_source = gv.source
327+
self.assertIn('ethanol', dot_source)
328+
self.assertIn('opt', dot_source)
329+
self.assertIn('energy', dot_source)
330+
self.assertIn('conformer selection', dot_source)
331+
self.assertIn('honeydew', dot_source) # done calc
332+
self.assertIn('cornsilk', dot_source) # data node
333+
self.assertIn('diamond', dot_source) # decision node
334+
self.assertIn('green3', dot_source) # selected_by edge
335+
336+
def test_save_provenance_artifacts_with_graph(self):
337+
"""Test that save_provenance_artifacts prefers graph-based rendering when a graph is provided."""
338+
from arc.provenance import (ProvenanceGraph, DecisionKind, EdgeType)
339+
project = 'arc_project_for_testing_delete_after_usage'
340+
project_directory = os.path.join(ARC_PATH, 'Projects', project)
341+
g = ProvenanceGraph(project=project)
342+
sid = g.add_species_node(label='spc1')
343+
cid = g.add_calculation_node(label='spc1', job_name='opt_a1',
344+
job_type='opt', status='done')
345+
g.add_edge(sid, cid, EdgeType.input_of)
346+
provenance = {'project': project, 'events': []}
347+
paths = plotter.save_provenance_artifacts(
348+
project_directory=project_directory,
349+
provenance=provenance,
350+
graph=g,
351+
)
352+
self.assertTrue(os.path.isfile(paths['yml']))
353+
if paths['dot'] is not None:
354+
with open(paths['dot'], 'r') as f:
355+
dot = f.read()
356+
# Graph-based rendering uses node IDs like species_1 not event-based species_spc1.
357+
self.assertIn('species_1', dot)
358+
self.assertIn('honeydew', dot)
303359

304360
@classmethod
305361
def tearDownClass(cls):

arc/provenance/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
ARC provenance subpackage — directed acyclic graph for computational provenance.
3+
4+
Tracks the full chain of inputs, calculations, decisions, and outputs that
5+
produce ARC's results. Inspired by AiiDA's DAG model but adapted for ARC's
6+
branching decision trees (TS guess evaluation, conformer selection,
7+
troubleshooting loops).
8+
9+
Submodules:
10+
- ``nodes``: Node types, edge types, and their data classes.
11+
- ``graph``: ProvenanceGraph container with query and serialization.
12+
"""
13+
14+
from arc.provenance.graph import ProvenanceGraph
15+
from arc.provenance.nodes import (
16+
CalculationNode,
17+
DataKind,
18+
DataNode,
19+
DecisionKind,
20+
DecisionNode,
21+
EdgeType,
22+
NodeType,
23+
ProvenanceEdge,
24+
ProvenanceNode,
25+
)
26+
27+
__all__ = [
28+
'ProvenanceGraph',
29+
'ProvenanceNode',
30+
'CalculationNode',
31+
'DataNode',
32+
'DecisionNode',
33+
'ProvenanceEdge',
34+
'NodeType',
35+
'DataKind',
36+
'DecisionKind',
37+
'EdgeType',
38+
]

0 commit comments

Comments
 (0)