Skip to content

Commit 570fc7d

Browse files
committed
fix dependency order
1 parent 0d08734 commit 570fc7d

File tree

1 file changed

+140
-93
lines changed

1 file changed

+140
-93
lines changed

.kokoro/determine_dependencies.py

Lines changed: 140 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,13 @@
2121
from typing import Dict, List, Set, Tuple
2222

2323
# Maven XML namespace
24-
NS = {'mvn': 'http://maven.apache.org/POM/4.0.0'}
24+
NS = {"mvn": "http://maven.apache.org/POM/4.0.0"}
25+
2526

2627
class Module:
27-
def __init__(self, path: str, group_id: str, artifact_id: str, parent: Tuple[str, str] = None):
28+
def __init__(
29+
self, path: str, group_id: str, artifact_id: str, parent: Tuple[str, str] = None
30+
):
2831
self.path = path
2932
self.group_id = group_id
3033
self.artifactId = artifact_id
@@ -38,6 +41,7 @@ def key(self) -> Tuple[str, str]:
3841
def __repr__(self):
3942
return f"{self.group_id}:{self.artifactId}"
4043

44+
4145
def parse_pom(path: str) -> Module:
4246
try:
4347
tree = ET.parse(path)
@@ -49,23 +53,23 @@ def parse_pom(path: str) -> Module:
4953
# Handle namespace if present
5054
# XML tags in ElementTree are {namespace}tag
5155
# We'll use find with namespaces for robustness, but simple logic for extraction
52-
56+
5357
# Helper to clean tag name
5458
def local_name(tag):
55-
if '}' in tag:
56-
return tag.split('}', 1)[1]
59+
if "}" in tag:
60+
return tag.split("}", 1)[1]
5761
return tag
5862

59-
parent_elem = root.find('mvn:parent', NS)
63+
parent_elem = root.find("mvn:parent", NS)
6064
parent_coords = None
6165
parent_group_id = None
6266
if parent_elem is not None:
63-
p_group = parent_elem.find('mvn:groupId', NS).text
64-
p_artifact = parent_elem.find('mvn:artifactId', NS).text
67+
p_group = parent_elem.find("mvn:groupId", NS).text
68+
p_artifact = parent_elem.find("mvn:artifactId", NS).text
6569
parent_coords = (p_group, p_artifact)
6670
parent_group_id = p_group
6771

68-
group_id_elem = root.find('mvn:groupId', NS)
72+
group_id_elem = root.find("mvn:groupId", NS)
6973
# Inherit groupId from parent if not specified
7074
if group_id_elem is not None:
7175
group_id = group_id_elem.text
@@ -75,70 +79,80 @@ def local_name(tag):
7579
# Fallback or error? For now, use artifactId as heuristic or empty
7680
group_id = "unknown"
7781

78-
artifact_id = root.find('mvn:artifactId', NS).text
79-
82+
artifact_id = root.find("mvn:artifactId", NS).text
83+
8084
module = Module(path, group_id, artifact_id, parent_coords)
8185

8286
# Dependencies
83-
def add_dependencies(section):
87+
def add_dependencies(section, is_management=False):
8488
if section is not None:
85-
for dep in section.findall('mvn:dependency', NS):
86-
d_group = dep.find('mvn:groupId', NS)
87-
d_artifact = dep.find('mvn:artifactId', NS)
89+
for dep in section.findall("mvn:dependency", NS):
90+
if is_management:
91+
# Only include 'import' scope dependencies from dependencyManagement
92+
scope = dep.find("mvn:scope", NS)
93+
if scope is None or scope.text != "import":
94+
continue
95+
96+
d_group = dep.find("mvn:groupId", NS)
97+
d_artifact = dep.find("mvn:artifactId", NS)
8898
if d_group is not None and d_artifact is not None:
8999
module.dependencies.add((d_group.text, d_artifact.text))
90100

91-
add_dependencies(root.find('mvn:dependencies', NS))
92-
93-
dep_mgmt = root.find('mvn:dependencyManagement', NS)
101+
add_dependencies(root.find("mvn:dependencies", NS))
102+
103+
dep_mgmt = root.find("mvn:dependencyManagement", NS)
94104
if dep_mgmt is not None:
95-
add_dependencies(dep_mgmt.find('mvn:dependencies', NS))
105+
add_dependencies(dep_mgmt.find("mvn:dependencies", NS), is_management=True)
96106

97107
# Plugin dependencies
98-
build = root.find('mvn:build', NS)
108+
build = root.find("mvn:build", NS)
99109
if build is not None:
100-
plugins = build.find('mvn:plugins', NS)
110+
plugins = build.find("mvn:plugins", NS)
101111
if plugins is not None:
102-
for plugin in plugins.findall('mvn:plugin', NS):
112+
for plugin in plugins.findall("mvn:plugin", NS):
103113
# Plugin itself
104-
p_group = plugin.find('mvn:groupId', NS)
105-
p_artifact = plugin.find('mvn:artifactId', NS)
114+
p_group = plugin.find("mvn:groupId", NS)
115+
p_artifact = plugin.find("mvn:artifactId", NS)
106116
if p_group is not None and p_artifact is not None:
107117
module.dependencies.add((p_group.text, p_artifact.text))
108-
118+
109119
# Plugin dependencies
110-
add_dependencies(plugin.find('mvn:dependencies', NS))
111-
120+
add_dependencies(plugin.find("mvn:dependencies", NS))
121+
112122
# Plugin Management
113-
plugin_mgmt = build.find('mvn:pluginManagement', NS)
123+
plugin_mgmt = build.find("mvn:pluginManagement", NS)
114124
if plugin_mgmt is not None:
115-
plugins = plugin_mgmt.find('mvn:plugins', NS)
116-
if plugins is not None:
117-
for plugin in plugins.findall('mvn:plugin', NS):
125+
plugins = plugin_mgmt.find("mvn:plugins", NS)
126+
if plugins is not None:
127+
for plugin in plugins.findall("mvn:plugin", NS):
118128
# Plugin itself
119-
p_group = plugin.find('mvn:groupId', NS)
120-
p_artifact = plugin.find('mvn:artifactId', NS)
129+
p_group = plugin.find("mvn:groupId", NS)
130+
p_artifact = plugin.find("mvn:artifactId", NS)
121131
if p_group is not None and p_artifact is not None:
122132
module.dependencies.add((p_group.text, p_artifact.text))
123133

124-
add_dependencies(plugin.find('mvn:dependencies', NS))
134+
add_dependencies(plugin.find("mvn:dependencies", NS))
125135

126136
return module
127137

138+
128139
def find_poms(root_dir: str) -> List[str]:
129140
pom_files = []
130141
for dirpath, dirnames, filenames in os.walk(root_dir):
131142
# Skip hidden directories and known non-module dirs
132-
dirnames[:] = [d for d in dirnames if not d.startswith('.')]
133-
134-
if 'pom.xml' in filenames:
135-
pom_files.append(os.path.join(dirpath, 'pom.xml'))
143+
dirnames[:] = [d for d in dirnames if not d.startswith(".")]
144+
145+
if "pom.xml" in filenames:
146+
pom_files.append(os.path.join(dirpath, "pom.xml"))
136147
return pom_files
137148

138-
def build_graph(root_dir: str) -> Tuple[Dict[Tuple[str, str], Module], Dict[Tuple[str, str], Set[Tuple[str, str]]]]:
149+
150+
def build_graph(
151+
root_dir: str,
152+
) -> Tuple[Dict[Tuple[str, str], Module], Dict[Tuple[str, str], Set[Tuple[str, str]]]]:
139153
pom_paths = find_poms(root_dir)
140154
modules: Dict[Tuple[str, str], Module] = {}
141-
155+
142156
# First pass: load all modules
143157
for path in pom_paths:
144158
module = parse_pom(path)
@@ -148,26 +162,27 @@ def build_graph(root_dir: str) -> Tuple[Dict[Tuple[str, str], Module], Dict[Tupl
148162
# Build adjacency list: dependent -> dependencies (upstream)
149163
# Only include dependencies that are present in the repo
150164
graph: Dict[Tuple[str, str], Set[Tuple[str, str]]] = defaultdict(set)
151-
165+
152166
for key, module in modules.items():
153167
# Parent dependency
154168
if module.parent and module.parent in modules:
155169
graph[key].add(module.parent)
156-
170+
157171
# Regular dependencies
158172
for dep_key in module.dependencies:
159173
if dep_key in modules:
160174
graph[key].add(dep_key)
161-
175+
162176
return modules, graph
163177

178+
164179
def get_transitive_dependencies(
165-
start_nodes: List[Tuple[str, str]],
166-
graph: Dict[Tuple[str, str], Set[Tuple[str, str]]]
180+
start_nodes: List[Tuple[str, str]],
181+
graph: Dict[Tuple[str, str], Set[Tuple[str, str]]],
167182
) -> Set[Tuple[str, str]]:
168183
visited = set()
169184
stack = list(start_nodes)
170-
185+
171186
while stack:
172187
node = stack.pop()
173188
if node not in visited:
@@ -177,94 +192,126 @@ def get_transitive_dependencies(
177192
for upstream in graph[node]:
178193
if upstream not in visited:
179194
stack.append(upstream)
180-
195+
181196
return visited
182197

183-
def resolve_modules_from_inputs(inputs: List[str], modules_by_path: Dict[str, Module], modules_by_key: Dict[Tuple[str, str], Module]) -> List[Tuple[str, str]]:
198+
199+
def resolve_modules_from_inputs(
200+
inputs: List[str],
201+
modules_by_path: Dict[str, Module],
202+
modules_by_key: Dict[Tuple[str, str], Module],
203+
) -> List[Tuple[str, str]]:
184204
resolved = set()
185205
for item in inputs:
186206
# Check if item is a path
187207
abs_item = os.path.abspath(item)
188-
208+
189209
# If it's a file, try to find the nearest pom.xml
190-
if os.path.isfile(abs_item) or (not item.endswith('pom.xml') and os.path.isdir(abs_item)):
191-
# Heuristic: if it's a file, find containing pom
192-
# if it's a dir, look for pom.xml inside or check if it matches a module path
193-
candidate_path = abs_item
194-
if os.path.isfile(candidate_path) and not candidate_path.endswith('pom.xml'):
195-
candidate_path = os.path.dirname(candidate_path)
196-
197-
# Traverse up to find pom.xml
198-
while candidate_path.startswith(os.getcwd()) and len(candidate_path) >= len(os.getcwd()):
199-
pom_path = os.path.join(candidate_path, 'pom.xml')
200-
if pom_path in modules_by_path:
201-
resolved.add(modules_by_path[pom_path].key)
202-
break
203-
candidate_path = os.path.dirname(candidate_path)
204-
elif item.endswith('pom.xml') and os.path.abspath(item) in modules_by_path:
205-
resolved.add(modules_by_path[os.path.abspath(item)].key)
210+
if os.path.isfile(abs_item) or (
211+
not item.endswith("pom.xml") and os.path.isdir(abs_item)
212+
):
213+
# Heuristic: if it's a file, find containing pom
214+
# if it's a dir, look for pom.xml inside or check if it matches a module path
215+
candidate_path = abs_item
216+
if os.path.isfile(candidate_path) and not candidate_path.endswith(
217+
"pom.xml"
218+
):
219+
candidate_path = os.path.dirname(candidate_path)
220+
221+
# Traverse up to find pom.xml
222+
while candidate_path.startswith(os.getcwd()) and len(candidate_path) >= len(
223+
os.getcwd()
224+
):
225+
pom_path = os.path.join(candidate_path, "pom.xml")
226+
if pom_path in modules_by_path:
227+
resolved.add(modules_by_path[pom_path].key)
228+
break
229+
candidate_path = os.path.dirname(candidate_path)
230+
elif item.endswith("pom.xml") and os.path.abspath(item) in modules_by_path:
231+
resolved.add(modules_by_path[os.path.abspath(item)].key)
206232
else:
207233
# Try to match simple name (artifactId) or groupId:artifactId
208234
found = False
209235
for key, module in modules_by_key.items():
210-
if item == module.artifactId or item == f"{module.group_id}:{module.artifactId}":
236+
if (
237+
item == module.artifactId
238+
or item == f"{module.group_id}:{module.artifactId}"
239+
):
211240
resolved.add(key)
212241
found = True
213242
break
214243
if not found:
215-
print(f"Warning: Could not resolve input '{item}' to a module.", file=sys.stderr)
216-
244+
print(
245+
f"Warning: Could not resolve input '{item}' to a module.",
246+
file=sys.stderr,
247+
)
248+
217249
return list(resolved)
218250

251+
219252
def main():
220-
parser = argparse.ArgumentParser(description='Identify upstream dependencies for partial builds.')
221-
parser.add_argument('modules', nargs='+', help='List of modified modules or file paths')
253+
parser = argparse.ArgumentParser(
254+
description="Identify upstream dependencies for partial builds."
255+
)
256+
parser.add_argument(
257+
"modules", nargs="+", help="List of modified modules or file paths"
258+
)
222259
args = parser.parse_args()
223260

224261
root_dir = os.getcwd()
225262
modules_by_key, graph = build_graph(root_dir)
226263
modules_by_path = {m.path: m for m in modules_by_key.values()}
227264

228-
start_nodes = resolve_modules_from_inputs(args.modules, modules_by_path, modules_by_key)
229-
265+
start_nodes = resolve_modules_from_inputs(
266+
args.modules, modules_by_path, modules_by_key
267+
)
268+
230269
if not start_nodes:
231270
print("No valid modules found from input.", file=sys.stderr)
232271
return
233272

234273
# Get transitive upstream dependencies
235-
# We include the start nodes themselves in the output set if they are dependencies of other start nodes?
274+
# We include the start nodes themselves in the output set if they are dependencies of other start nodes?
236275
# Usually we want: Dependencies of (Start Nodes) NOT INCLUDING Start Nodes themselves, unless A depends on B and both are modified.
237276
# But for "installing dependencies", we generally want EVERYTHING upstream of the modified set.
238277
# If I modified A, and A depends on B, I want to install B.
239278
# If I modified A and B, and A depends on B, I want to install B (before A).
240279
# But usually the build system will build A and B if I say "build A and B".
241280
# The request is: "determine which modules will need to be compiled and installed to the local maven repository"
242281
# This implies we want the COMPLEMENT set of the modified modules, restricted to the upstream graph.
243-
282+
244283
all_dependencies = get_transitive_dependencies(start_nodes, graph)
245-
246-
# Filter out the start nodes themselves, because they are the "modified" ones
247-
# that will presumably be built by the test command itself?
248-
# Actually, usually 'dependencies.sh' installs everything needed for the tests.
249-
# If I am testing A, I need B installed.
250-
# If I am testing A and B, I need C (upstream of B) installed.
251-
# So I need (TransitiveClosure(StartNodes) - StartNodes).
252-
284+
253285
upstream_only = all_dependencies - set(start_nodes)
254-
255-
# Map back to paths or artifactIds?
256-
# `mvn -pl` takes directory paths or [groupId]:artifactId
257-
# Directory paths are safer if artifactIds are not unique (rare but possible)
258-
# relpath is good.
259-
286+
287+
# Topological sort for installation order
288+
# (Install dependencies before dependents)
289+
sorted_upstream = []
290+
visited_sort = set()
291+
292+
def visit(node):
293+
if node in visited_sort:
294+
return
295+
visited_sort.add(node)
296+
# Visit dependencies first
297+
if node in graph:
298+
for dep in graph[node]:
299+
if dep in upstream_only:
300+
visit(dep)
301+
302+
sorted_upstream.append(node)
303+
304+
for node in upstream_only:
305+
visit(node)
306+
260307
results = []
261-
for key in upstream_only:
308+
for key in sorted_upstream:
262309
module = modules_by_key[key]
263310
rel_path = os.path.relpath(os.path.dirname(module.path), root_dir)
264-
# Maven -pl expects project dir or group:artifact
265311
results.append(rel_path)
266-
267-
print(','.join(sorted(results)))
268312

269-
if __name__ == '__main__':
313+
print(",".join(results))
314+
315+
316+
if __name__ == "__main__":
270317
main()

0 commit comments

Comments
 (0)