-
Notifications
You must be signed in to change notification settings - Fork 208
Expand file tree
/
Copy pathembedded_modules.py
More file actions
299 lines (250 loc) · 10.3 KB
/
Copy pathembedded_modules.py
File metadata and controls
299 lines (250 loc) · 10.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
"""Auto-discover and embed all runtime helpers as a single block.
Instead of manually curating a __all__ list per function (which breaks when
new internal helpers are added), this module reads contributing source files,
strips their import statements and duplicate module logger setup via AST, and
returns the clean definitions ready for embedding in generated standalone scripts.
This guarantees that ALL functions, classes, and constants from the
contributing modules are embedded together — so internal cross-calls always
resolve without NameError.
"""
from __future__ import annotations
import ast
from pathlib import Path
# Modules whose top-level definitions should be embedded in generated scripts.
# Paths are relative to the package root, which the functions below compute as
# the parent of this module's directory (Path(__file__).resolve().parent.parent).
# Order matters: dependencies first, so functions are defined before callers use them.
# NOTE(review): names inside function bodies are resolved at call time, so the
# order presumably only matters for code that runs at module level inside the
# embedded block — confirm before relying on it when reordering.
_SOURCE_FILES: list[str] = [
    "runtime/module_loader.py",  # _load_module, _bootstrap_import — no internal deps
    "runtime/path_discovery.py",  # get_comfyui_path, find_path, etc.
    "runtime/bootstrap.py",  # CLI filtering; depends on module_loader
    "node_runtime.py",  # public API facade + bootstrap/cleanup
]
# Manifest of the exact top-level names the embedded helper block is expected
# to define. verify_embedded_surface_matches_manifest() diffs this set against
# list_embedded_names(), so adding or removing a top-level helper in any module
# listed in _SOURCE_FILES is reported as a difference until this set is updated.
APPROVED_EMBEDDED_NAMES: frozenset[str] = frozenset(
    {
        "_apply_device_settings",
        "_apply_directory_overrides",
        "_bootstrap_import",
        "_discover_comfyui_cli_options",
        "_filter_comfyui_args",
        "_find_file",
        "_find_from_extension_location",
        "_get_base_option",
        "_init_extra_nodes",
        "_is_comfyui_directory",
        "_load_custom_node_modules",
        "_load_module",
        "_load_module_temp",
        "_parse_parser_actions",
        "add_comfyui_directory_to_sys_path",
        "add_extra_model_paths",
        "bootstrap_comfyui_runtime",
        "cleanup_comfyui_runtime",
        "find_path",
        "get_comfyui_path",
        "get_node_class_mappings",
        "get_value_at_index",
        "import_custom_nodes",
    }
)
def _strip_imports(source: str) -> str:
    """Remove non-embeddable top-level statements from Python source code.

    Uses AST to find top-level import statements and module logger
    assignments (as recognized by ``_is_module_logger_assignment``), then
    rebuilds the source with those lines removed while preserving everything
    else (functions, classes, constants, docstrings, comments). Imports that
    are not at module level (e.g. inside functions) are left untouched.

    Args:
        source: Full Python source code of a module.

    Returns:
        The remaining source with leading/trailing whitespace stripped and a
        single trailing newline appended.
    """
    import io  # local import: the module does not import io at top level

    tree = ast.parse(source)
    # Split only on the physical line terminators the Python tokenizer uses
    # (\n, \r\n, \r) so list positions line up with AST lineno values.
    # str.splitlines() would also break on \f, \v, \x1c-\x1e, \x85, \u2028
    # and \u2029, shifting every later line whenever one of those characters
    # appears inside a comment or string literal.
    lines = io.StringIO(source, newline="").readlines()

    # Collect line numbers (1-indexed) of top-level statements to remove.
    skip_lines: set[int] = set()
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)) or (
            isinstance(node, (ast.Assign, ast.AnnAssign))
            and _is_module_logger_assignment(node)
        ):
            # Cover the whole span so multi-line imports
            # (from x import (a,\n b)) are removed completely.
            start = node.lineno or 1
            end = getattr(node, "end_lineno", start) or start
            skip_lines.update(range(start, end + 1))

    # Also drop the first blank line after each run of removed lines to
    # avoid leaving large whitespace gaps.
    result_lines: list[str] = []
    prev_was_removed = False
    for i, line in enumerate(lines, start=1):
        if i in skip_lines:
            prev_was_removed = True
            continue
        if prev_was_removed and line.strip() == "":
            prev_was_removed = False
            continue
        prev_was_removed = False
        result_lines.append(line)
    return "".join(result_lines).strip() + "\n"
def _is_module_logger_assignment(node: ast.Assign | ast.AnnAssign) -> bool:
    """Tell whether *node* is exactly a ``log = logging.getLogger(...)`` line.

    Both the plain (``log = ...``) and annotated (``log: T = ...``) forms are
    accepted. Every assignment target must be the bare name ``log``, and the
    value must be a call to the ``getLogger`` attribute of a name spelled
    ``logging``. Annotation-only statements (no value) never match.
    """
    targets = node.targets if isinstance(node, ast.Assign) else [node.target]
    if node.value is None or not targets:
        return False
    if not all(isinstance(t, ast.Name) and t.id == "log" for t in targets):
        return False
    call = node.value
    if not isinstance(call, ast.Call):
        return False
    callee = call.func
    return (
        isinstance(callee, ast.Attribute)
        and callee.attr == "getLogger"
        and isinstance(callee.value, ast.Name)
        and callee.value.id == "logging"
    )
def get_embedded_helpers() -> str:
    """Return the full embedded helper block for generated scripts.

    Reads each source file listed in _SOURCE_FILES (resolved against the
    package root, the parent of this module's directory), strips its
    top-level import statements and module logger assignment, and joins the
    results into a single embeddable code block. Each module's code is
    preceded by a ``# --- Embedded from <rel_path> ---`` header comment.

    Returns:
        Python source code containing all function/class/constant definitions
        from contributing modules, ready to paste into a generated script.

    Raises:
        FileNotFoundError: If a file listed in _SOURCE_FILES does not exist.
    """
    package_root = Path(__file__).resolve().parent.parent  # comfyui_to_python/
    parts: list[str] = []
    for rel_path in _SOURCE_FILES:
        filepath = package_root / rel_path
        if not filepath.exists():
            raise FileNotFoundError(
                f"Embedded source file not found: {filepath}\n"
                f"If you added a new contributing module, update "
                f"_SOURCE_FILES in {__file__}"
            )
        parts.append(f"# --- Embedded from {rel_path} ---\n")
        # Python source files are UTF-8 by default (PEP 3120); reading with
        # the locale encoding would corrupt non-ASCII text such as "—".
        parts.append(_strip_imports(filepath.read_text(encoding="utf-8")))
    return "\n".join(parts)
def verify_embedded_surface_matches_manifest() -> list[str]:
    """Diff the live embedded surface against APPROVED_EMBEDDED_NAMES.

    Returns:
        Human-readable difference lines: first every approved name missing
        from the embedded block, then every embedded name that is not
        approved, each group in sorted order. Empty when the sets match.
    """
    embedded_now = list_embedded_names()
    report: list[str] = []
    for name in sorted(APPROVED_EMBEDDED_NAMES - embedded_now):
        report.append(f"missing approved embedded name: {name}")
    for name in sorted(embedded_now - APPROVED_EMBEDDED_NAMES):
        report.append(f"unexpected embedded name: {name}")
    return report
def list_embedded_names() -> set[str]:
    """Return the set of all top-level names that will be embedded.

    Collects, from every module in _SOURCE_FILES, the names of top-level
    functions and classes plus every name bound by a top-level assignment.
    That covers plain assignments (including tuple/list unpacking and starred
    targets) and annotated assignments that carry a value. The module logger
    name ``log`` is excluded. Useful for testing / verification to ensure no
    unexpected names are included and to check for missing dependencies.

    Returns:
        Set of function/class/constant names from contributing modules.
    """
    package_root = Path(__file__).resolve().parent.parent
    names: set[str] = set()
    for rel_path in _SOURCE_FILES:
        filepath = package_root / rel_path
        # Python source files are UTF-8 by default (PEP 3120).
        tree = ast.parse(filepath.read_text(encoding="utf-8"))
        for node in ast.iter_child_nodes(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                names.add(node.name)
                continue
            if isinstance(node, ast.Assign):
                targets = node.targets
            elif isinstance(node, ast.AnnAssign) and node.value is not None:
                # A bare annotation (`x: int`) binds nothing at runtime, but
                # `x: int = 1` does and is kept by _strip_imports.
                targets = [node.target]
            else:
                continue
            for target in targets:
                # Store-context Names are the names actually bound; this
                # handles `A, B = ...` / `*rest` and ignores `obj.attr = ...`
                # or `d[k] = ...`, whose Names are only loaded.
                for sub in ast.walk(target):
                    if (
                        isinstance(sub, ast.Name)
                        and isinstance(sub.ctx, ast.Store)
                        # Skip common non-embeddable names like 'log'
                        and sub.id not in ("log",)
                    ):
                        names.add(sub.id)
    return names
def verify_no_missing_cross_calls() -> list[str]:
    """Check that all function calls within embedded code resolve to embedded names.

    Scans each contributing module for calls made through a bare name
    (``foo(...)``) and reports every call whose name is not a builtin, one of
    the common exception/typing names below, an embedded top-level name, or a
    name bound locally inside some function of the same file (nested defs,
    locals created dynamically, e.g. via getattr on runtime-loaded modules).

    Returns:
        One ``"<rel_path>: calls '<name>' (not embedded, not builtin)"`` entry
        per offending call site, in AST walk order (a name called in several
        places is reported once per place); empty if everything resolves.
    """
    import builtins as _builtins

    embedded = list_embedded_names()
    package_root = Path(__file__).resolve().parent.parent
    builtin_names = set(dir(_builtins))
    # Exception classes that are common in error handling
    exception_names = {
        "Exception",
        "BaseException",
        "ValueError",
        "TypeError",
        "KeyError",
        "AttributeError",
        "ModuleNotFoundError",
        "ImportError",
        "FileNotFoundError",
        "RuntimeError",
        "StopIteration",
        "IndexError",
        "OSError",
    }
    # Type hint names from typing module
    typing_names = {
        "Any",
        "Sequence",
        "Mapping",
        "Union",
        "Optional",
        "List",
        "Dict",
        "Set",
        "Tuple",
        "FrozenSet",
        "Callable",
    }
    all_known = builtin_names | exception_names | typing_names | embedded
    unresolved: list[str] = []
    for rel_path in _SOURCE_FILES:
        filepath = package_root / rel_path
        # Python source files are UTF-8 by default (PEP 3120).
        tree = ast.parse(filepath.read_text(encoding="utf-8"))
        # _is_nested_or_local_def re-reads and re-parses the whole file, so
        # remember its answer per name instead of asking at every call site.
        is_local_by_name: dict[str, bool] = {}
        # Find all Name nodes used as function calls
        for node in ast.walk(tree):
            if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)):
                continue
            caller_name = node.func.id
            if caller_name in all_known:
                continue
            if caller_name not in is_local_by_name:
                is_local_by_name[caller_name] = _is_nested_or_local_def(
                    filepath, caller_name
                )
            if not is_local_by_name[caller_name]:
                unresolved.append(
                    f"{rel_path}: calls '{caller_name}' "
                    "(not embedded, not builtin)"
                )
    return unresolved
def _is_nested_or_local_def(filepath: Path, name: str) -> bool:
    """Check if a name is bound anywhere inside a function body in a file.

    Every function/lambda in the file is scanned, including methods and
    nested functions. A name counts as local when, inside such a scope, it is:

    - the name of a (nested) function or class, including the function itself;
    - a parameter (positional, keyword-only, ``*args`` or ``**kwargs``);
    - an assignment target in any form: plain, annotated, augmented,
      tuple-unpacked, ``for`` target, ``with ... as``, walrus, comprehension;
    - an ``except ... as`` name or a locally imported name.

    Catches patterns like:
        def outer(callback):
            if cond:
                x = getattr(mod, "x")  # 'x' is a local var (nested inside if)
                return x()
            return callback()          # 'callback' is a parameter

    Names bound only at module level are NOT reported as local.

    Args:
        filepath: Python source file to inspect (read as UTF-8).
        name: Identifier to look for.

    Returns:
        True if ``name`` is bound inside some function scope of the file.
    """
    tree = ast.parse(filepath.read_text(encoding="utf-8"))
    local_names: set[str] = set()
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)):
            continue
        # Walk ALL descendants of the scope (including if/else/try bodies);
        # ast.walk also yields the scope node itself, so a function's own
        # name is recorded too.
        for child in ast.walk(node):
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                local_names.add(child.name)
            elif isinstance(child, ast.arg):
                local_names.add(child.arg)
            elif isinstance(child, ast.Name) and isinstance(child.ctx, ast.Store):
                local_names.add(child.id)
            elif isinstance(child, ast.ExceptHandler) and child.name:
                local_names.add(child.name)
            elif isinstance(child, ast.alias) and child.name != "*":
                # `import a.b` binds `a`; `import a as b` / `from m import a as b` bind `b`.
                local_names.add(child.asname or child.name.split(".")[0])
    return name in local_names