Skip to content

Commit 0d08734

Browse files
committed
build: add script to determine interdependencies between modules in the monorepo
1 parent 1115222 commit 0d08734

File tree

1 file changed

+270
-0
lines changed

1 file changed

+270
-0
lines changed

.kokoro/determine_dependencies.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2026 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
import os
18+
import sys
19+
import xml.etree.ElementTree as ET
20+
from collections import defaultdict
21+
from typing import Dict, List, Set, Tuple
22+
23+
# Maven XML namespace
24+
NS = {'mvn': 'http://maven.apache.org/POM/4.0.0'}
25+
26+
class Module:
27+
def __init__(self, path: str, group_id: str, artifact_id: str, parent: Tuple[str, str] = None):
28+
self.path = path
29+
self.group_id = group_id
30+
self.artifactId = artifact_id
31+
self.parent = parent
32+
self.dependencies: Set[Tuple[str, str]] = set()
33+
34+
@property
35+
def key(self) -> Tuple[str, str]:
36+
return (self.group_id, self.artifactId)
37+
38+
def __repr__(self):
39+
return f"{self.group_id}:{self.artifactId}"
40+
41+
def parse_pom(path: str) -> Module:
42+
try:
43+
tree = ET.parse(path)
44+
root = tree.getroot()
45+
except ET.ParseError as e:
46+
print(f"Error parsing {path}: {e}", file=sys.stderr)
47+
return None
48+
49+
# Handle namespace if present
50+
# XML tags in ElementTree are {namespace}tag
51+
# We'll use find with namespaces for robustness, but simple logic for extraction
52+
53+
# Helper to clean tag name
54+
def local_name(tag):
55+
if '}' in tag:
56+
return tag.split('}', 1)[1]
57+
return tag
58+
59+
parent_elem = root.find('mvn:parent', NS)
60+
parent_coords = None
61+
parent_group_id = None
62+
if parent_elem is not None:
63+
p_group = parent_elem.find('mvn:groupId', NS).text
64+
p_artifact = parent_elem.find('mvn:artifactId', NS).text
65+
parent_coords = (p_group, p_artifact)
66+
parent_group_id = p_group
67+
68+
group_id_elem = root.find('mvn:groupId', NS)
69+
# Inherit groupId from parent if not specified
70+
if group_id_elem is not None:
71+
group_id = group_id_elem.text
72+
elif parent_group_id:
73+
group_id = parent_group_id
74+
else:
75+
# Fallback or error? For now, use artifactId as heuristic or empty
76+
group_id = "unknown"
77+
78+
artifact_id = root.find('mvn:artifactId', NS).text
79+
80+
module = Module(path, group_id, artifact_id, parent_coords)
81+
82+
# Dependencies
83+
def add_dependencies(section):
84+
if section is not None:
85+
for dep in section.findall('mvn:dependency', NS):
86+
d_group = dep.find('mvn:groupId', NS)
87+
d_artifact = dep.find('mvn:artifactId', NS)
88+
if d_group is not None and d_artifact is not None:
89+
module.dependencies.add((d_group.text, d_artifact.text))
90+
91+
add_dependencies(root.find('mvn:dependencies', NS))
92+
93+
dep_mgmt = root.find('mvn:dependencyManagement', NS)
94+
if dep_mgmt is not None:
95+
add_dependencies(dep_mgmt.find('mvn:dependencies', NS))
96+
97+
# Plugin dependencies
98+
build = root.find('mvn:build', NS)
99+
if build is not None:
100+
plugins = build.find('mvn:plugins', NS)
101+
if plugins is not None:
102+
for plugin in plugins.findall('mvn:plugin', NS):
103+
# Plugin itself
104+
p_group = plugin.find('mvn:groupId', NS)
105+
p_artifact = plugin.find('mvn:artifactId', NS)
106+
if p_group is not None and p_artifact is not None:
107+
module.dependencies.add((p_group.text, p_artifact.text))
108+
109+
# Plugin dependencies
110+
add_dependencies(plugin.find('mvn:dependencies', NS))
111+
112+
# Plugin Management
113+
plugin_mgmt = build.find('mvn:pluginManagement', NS)
114+
if plugin_mgmt is not None:
115+
plugins = plugin_mgmt.find('mvn:plugins', NS)
116+
if plugins is not None:
117+
for plugin in plugins.findall('mvn:plugin', NS):
118+
# Plugin itself
119+
p_group = plugin.find('mvn:groupId', NS)
120+
p_artifact = plugin.find('mvn:artifactId', NS)
121+
if p_group is not None and p_artifact is not None:
122+
module.dependencies.add((p_group.text, p_artifact.text))
123+
124+
add_dependencies(plugin.find('mvn:dependencies', NS))
125+
126+
return module
127+
128+
def find_poms(root_dir: str) -> List[str]:
129+
pom_files = []
130+
for dirpath, dirnames, filenames in os.walk(root_dir):
131+
# Skip hidden directories and known non-module dirs
132+
dirnames[:] = [d for d in dirnames if not d.startswith('.')]
133+
134+
if 'pom.xml' in filenames:
135+
pom_files.append(os.path.join(dirpath, 'pom.xml'))
136+
return pom_files
137+
138+
def build_graph(root_dir: str) -> Tuple[Dict[Tuple[str, str], Module], Dict[Tuple[str, str], Set[Tuple[str, str]]]]:
139+
pom_paths = find_poms(root_dir)
140+
modules: Dict[Tuple[str, str], Module] = {}
141+
142+
# First pass: load all modules
143+
for path in pom_paths:
144+
module = parse_pom(path)
145+
if module:
146+
modules[module.key] = module
147+
148+
# Build adjacency list: dependent -> dependencies (upstream)
149+
# Only include dependencies that are present in the repo
150+
graph: Dict[Tuple[str, str], Set[Tuple[str, str]]] = defaultdict(set)
151+
152+
for key, module in modules.items():
153+
# Parent dependency
154+
if module.parent and module.parent in modules:
155+
graph[key].add(module.parent)
156+
157+
# Regular dependencies
158+
for dep_key in module.dependencies:
159+
if dep_key in modules:
160+
graph[key].add(dep_key)
161+
162+
return modules, graph
163+
164+
def get_transitive_dependencies(
165+
start_nodes: List[Tuple[str, str]],
166+
graph: Dict[Tuple[str, str], Set[Tuple[str, str]]]
167+
) -> Set[Tuple[str, str]]:
168+
visited = set()
169+
stack = list(start_nodes)
170+
171+
while stack:
172+
node = stack.pop()
173+
if node not in visited:
174+
visited.add(node)
175+
# Add upstream dependencies to stack
176+
if node in graph:
177+
for upstream in graph[node]:
178+
if upstream not in visited:
179+
stack.append(upstream)
180+
181+
return visited
182+
183+
def resolve_modules_from_inputs(inputs: List[str], modules_by_path: Dict[str, Module], modules_by_key: Dict[Tuple[str, str], Module]) -> List[Tuple[str, str]]:
184+
resolved = set()
185+
for item in inputs:
186+
# Check if item is a path
187+
abs_item = os.path.abspath(item)
188+
189+
# If it's a file, try to find the nearest pom.xml
190+
if os.path.isfile(abs_item) or (not item.endswith('pom.xml') and os.path.isdir(abs_item)):
191+
# Heuristic: if it's a file, find containing pom
192+
# if it's a dir, look for pom.xml inside or check if it matches a module path
193+
candidate_path = abs_item
194+
if os.path.isfile(candidate_path) and not candidate_path.endswith('pom.xml'):
195+
candidate_path = os.path.dirname(candidate_path)
196+
197+
# Traverse up to find pom.xml
198+
while candidate_path.startswith(os.getcwd()) and len(candidate_path) >= len(os.getcwd()):
199+
pom_path = os.path.join(candidate_path, 'pom.xml')
200+
if pom_path in modules_by_path:
201+
resolved.add(modules_by_path[pom_path].key)
202+
break
203+
candidate_path = os.path.dirname(candidate_path)
204+
elif item.endswith('pom.xml') and os.path.abspath(item) in modules_by_path:
205+
resolved.add(modules_by_path[os.path.abspath(item)].key)
206+
else:
207+
# Try to match simple name (artifactId) or groupId:artifactId
208+
found = False
209+
for key, module in modules_by_key.items():
210+
if item == module.artifactId or item == f"{module.group_id}:{module.artifactId}":
211+
resolved.add(key)
212+
found = True
213+
break
214+
if not found:
215+
print(f"Warning: Could not resolve input '{item}' to a module.", file=sys.stderr)
216+
217+
return list(resolved)
218+
219+
def main():
220+
parser = argparse.ArgumentParser(description='Identify upstream dependencies for partial builds.')
221+
parser.add_argument('modules', nargs='+', help='List of modified modules or file paths')
222+
args = parser.parse_args()
223+
224+
root_dir = os.getcwd()
225+
modules_by_key, graph = build_graph(root_dir)
226+
modules_by_path = {m.path: m for m in modules_by_key.values()}
227+
228+
start_nodes = resolve_modules_from_inputs(args.modules, modules_by_path, modules_by_key)
229+
230+
if not start_nodes:
231+
print("No valid modules found from input.", file=sys.stderr)
232+
return
233+
234+
# Get transitive upstream dependencies
235+
# We include the start nodes themselves in the output set if they are dependencies of other start nodes?
236+
# Usually we want: Dependencies of (Start Nodes) NOT INCLUDING Start Nodes themselves, unless A depends on B and both are modified.
237+
# But for "installing dependencies", we generally want EVERYTHING upstream of the modified set.
238+
# If I modified A, and A depends on B, I want to install B.
239+
# If I modified A and B, and A depends on B, I want to install B (before A).
240+
# But usually the build system will build A and B if I say "build A and B".
241+
# The request is: "determine which modules will need to be compiled and installed to the local maven repository"
242+
# This implies we want the COMPLEMENT set of the modified modules, restricted to the upstream graph.
243+
244+
all_dependencies = get_transitive_dependencies(start_nodes, graph)
245+
246+
# Filter out the start nodes themselves, because they are the "modified" ones
247+
# that will presumably be built by the test command itself?
248+
# Actually, usually 'dependencies.sh' installs everything needed for the tests.
249+
# If I am testing A, I need B installed.
250+
# If I am testing A and B, I need C (upstream of B) installed.
251+
# So I need (TransitiveClosure(StartNodes) - StartNodes).
252+
253+
upstream_only = all_dependencies - set(start_nodes)
254+
255+
# Map back to paths or artifactIds?
256+
# `mvn -pl` takes directory paths or [groupId]:artifactId
257+
# Directory paths are safer if artifactIds are not unique (rare but possible)
258+
# relpath is good.
259+
260+
results = []
261+
for key in upstream_only:
262+
module = modules_by_key[key]
263+
rel_path = os.path.relpath(os.path.dirname(module.path), root_dir)
264+
# Maven -pl expects project dir or group:artifact
265+
results.append(rel_path)
266+
267+
print(','.join(sorted(results)))
268+
269+
if __name__ == '__main__':
270+
main()

0 commit comments

Comments
 (0)