2121from typing import Dict , List , Set , Tuple
2222
2323# Maven XML namespace
24- NS = {'mvn' : 'http://maven.apache.org/POM/4.0.0' }
24+ NS = {"mvn" : "http://maven.apache.org/POM/4.0.0" }
25+
2526
2627class Module :
27- def __init__ (self , path : str , group_id : str , artifact_id : str , parent : Tuple [str , str ] = None ):
28+ def __init__ (
29+ self , path : str , group_id : str , artifact_id : str , parent : Tuple [str , str ] = None
30+ ):
2831 self .path = path
2932 self .group_id = group_id
3033 self .artifactId = artifact_id
@@ -38,6 +41,7 @@ def key(self) -> Tuple[str, str]:
3841 def __repr__ (self ):
3942 return f"{ self .group_id } :{ self .artifactId } "
4043
44+
4145def parse_pom (path : str ) -> Module :
4246 try :
4347 tree = ET .parse (path )
@@ -49,23 +53,23 @@ def parse_pom(path: str) -> Module:
4953 # Handle namespace if present
5054 # XML tags in ElementTree are {namespace}tag
5155 # We'll use find with namespaces for robustness, but simple logic for extraction
52-
56+
5357 # Helper to clean tag name
5458 def local_name (tag ):
55- if '}' in tag :
56- return tag .split ('}' , 1 )[1 ]
59+ if "}" in tag :
60+ return tag .split ("}" , 1 )[1 ]
5761 return tag
5862
59- parent_elem = root .find (' mvn:parent' , NS )
63+ parent_elem = root .find (" mvn:parent" , NS )
6064 parent_coords = None
6165 parent_group_id = None
6266 if parent_elem is not None :
63- p_group = parent_elem .find (' mvn:groupId' , NS ).text
64- p_artifact = parent_elem .find (' mvn:artifactId' , NS ).text
67+ p_group = parent_elem .find (" mvn:groupId" , NS ).text
68+ p_artifact = parent_elem .find (" mvn:artifactId" , NS ).text
6569 parent_coords = (p_group , p_artifact )
6670 parent_group_id = p_group
6771
68- group_id_elem = root .find (' mvn:groupId' , NS )
72+ group_id_elem = root .find (" mvn:groupId" , NS )
6973 # Inherit groupId from parent if not specified
7074 if group_id_elem is not None :
7175 group_id = group_id_elem .text
@@ -75,70 +79,80 @@ def local_name(tag):
7579 # Fallback or error? For now, use artifactId as heuristic or empty
7680 group_id = "unknown"
7781
78- artifact_id = root .find (' mvn:artifactId' , NS ).text
79-
82+ artifact_id = root .find (" mvn:artifactId" , NS ).text
83+
8084 module = Module (path , group_id , artifact_id , parent_coords )
8185
8286 # Dependencies
83- def add_dependencies (section ):
87+ def add_dependencies (section , is_management = False ):
8488 if section is not None :
85- for dep in section .findall ('mvn:dependency' , NS ):
86- d_group = dep .find ('mvn:groupId' , NS )
87- d_artifact = dep .find ('mvn:artifactId' , NS )
89+ for dep in section .findall ("mvn:dependency" , NS ):
90+ if is_management :
91+ # Only include 'import' scope dependencies from dependencyManagement
92+ scope = dep .find ("mvn:scope" , NS )
93+ if scope is None or scope .text != "import" :
94+ continue
95+
96+ d_group = dep .find ("mvn:groupId" , NS )
97+ d_artifact = dep .find ("mvn:artifactId" , NS )
8898 if d_group is not None and d_artifact is not None :
8999 module .dependencies .add ((d_group .text , d_artifact .text ))
90100
91- add_dependencies (root .find (' mvn:dependencies' , NS ))
92-
93- dep_mgmt = root .find (' mvn:dependencyManagement' , NS )
101+ add_dependencies (root .find (" mvn:dependencies" , NS ))
102+
103+ dep_mgmt = root .find (" mvn:dependencyManagement" , NS )
94104 if dep_mgmt is not None :
95- add_dependencies (dep_mgmt .find (' mvn:dependencies' , NS ))
105+ add_dependencies (dep_mgmt .find (" mvn:dependencies" , NS ), is_management = True )
96106
97107 # Plugin dependencies
98- build = root .find (' mvn:build' , NS )
108+ build = root .find (" mvn:build" , NS )
99109 if build is not None :
100- plugins = build .find (' mvn:plugins' , NS )
110+ plugins = build .find (" mvn:plugins" , NS )
101111 if plugins is not None :
102- for plugin in plugins .findall (' mvn:plugin' , NS ):
112+ for plugin in plugins .findall (" mvn:plugin" , NS ):
103113 # Plugin itself
104- p_group = plugin .find (' mvn:groupId' , NS )
105- p_artifact = plugin .find (' mvn:artifactId' , NS )
114+ p_group = plugin .find (" mvn:groupId" , NS )
115+ p_artifact = plugin .find (" mvn:artifactId" , NS )
106116 if p_group is not None and p_artifact is not None :
107117 module .dependencies .add ((p_group .text , p_artifact .text ))
108-
118+
109119 # Plugin dependencies
110- add_dependencies (plugin .find (' mvn:dependencies' , NS ))
111-
120+ add_dependencies (plugin .find (" mvn:dependencies" , NS ))
121+
112122 # Plugin Management
113- plugin_mgmt = build .find (' mvn:pluginManagement' , NS )
123+ plugin_mgmt = build .find (" mvn:pluginManagement" , NS )
114124 if plugin_mgmt is not None :
115- plugins = plugin_mgmt .find (' mvn:plugins' , NS )
116- if plugins is not None :
117- for plugin in plugins .findall (' mvn:plugin' , NS ):
125+ plugins = plugin_mgmt .find (" mvn:plugins" , NS )
126+ if plugins is not None :
127+ for plugin in plugins .findall (" mvn:plugin" , NS ):
118128 # Plugin itself
119- p_group = plugin .find (' mvn:groupId' , NS )
120- p_artifact = plugin .find (' mvn:artifactId' , NS )
129+ p_group = plugin .find (" mvn:groupId" , NS )
130+ p_artifact = plugin .find (" mvn:artifactId" , NS )
121131 if p_group is not None and p_artifact is not None :
122132 module .dependencies .add ((p_group .text , p_artifact .text ))
123133
124- add_dependencies (plugin .find (' mvn:dependencies' , NS ))
134+ add_dependencies (plugin .find (" mvn:dependencies" , NS ))
125135
126136 return module
127137
138+
128139def find_poms (root_dir : str ) -> List [str ]:
129140 pom_files = []
130141 for dirpath , dirnames , filenames in os .walk (root_dir ):
131142 # Skip hidden directories and known non-module dirs
132- dirnames [:] = [d for d in dirnames if not d .startswith ('.' )]
133-
134- if ' pom.xml' in filenames :
135- pom_files .append (os .path .join (dirpath , ' pom.xml' ))
143+ dirnames [:] = [d for d in dirnames if not d .startswith ("." )]
144+
145+ if " pom.xml" in filenames :
146+ pom_files .append (os .path .join (dirpath , " pom.xml" ))
136147 return pom_files
137148
138- def build_graph (root_dir : str ) -> Tuple [Dict [Tuple [str , str ], Module ], Dict [Tuple [str , str ], Set [Tuple [str , str ]]]]:
149+
150+ def build_graph (
151+ root_dir : str ,
152+ ) -> Tuple [Dict [Tuple [str , str ], Module ], Dict [Tuple [str , str ], Set [Tuple [str , str ]]]]:
139153 pom_paths = find_poms (root_dir )
140154 modules : Dict [Tuple [str , str ], Module ] = {}
141-
155+
142156 # First pass: load all modules
143157 for path in pom_paths :
144158 module = parse_pom (path )
@@ -148,26 +162,27 @@ def build_graph(root_dir: str) -> Tuple[Dict[Tuple[str, str], Module], Dict[Tupl
148162 # Build adjacency list: dependent -> dependencies (upstream)
149163 # Only include dependencies that are present in the repo
150164 graph : Dict [Tuple [str , str ], Set [Tuple [str , str ]]] = defaultdict (set )
151-
165+
152166 for key , module in modules .items ():
153167 # Parent dependency
154168 if module .parent and module .parent in modules :
155169 graph [key ].add (module .parent )
156-
170+
157171 # Regular dependencies
158172 for dep_key in module .dependencies :
159173 if dep_key in modules :
160174 graph [key ].add (dep_key )
161-
175+
162176 return modules , graph
163177
178+
164179def get_transitive_dependencies (
165- start_nodes : List [Tuple [str , str ]],
166- graph : Dict [Tuple [str , str ], Set [Tuple [str , str ]]]
180+ start_nodes : List [Tuple [str , str ]],
181+ graph : Dict [Tuple [str , str ], Set [Tuple [str , str ]]],
167182) -> Set [Tuple [str , str ]]:
168183 visited = set ()
169184 stack = list (start_nodes )
170-
185+
171186 while stack :
172187 node = stack .pop ()
173188 if node not in visited :
@@ -177,94 +192,126 @@ def get_transitive_dependencies(
177192 for upstream in graph [node ]:
178193 if upstream not in visited :
179194 stack .append (upstream )
180-
195+
181196 return visited
182197
183- def resolve_modules_from_inputs (inputs : List [str ], modules_by_path : Dict [str , Module ], modules_by_key : Dict [Tuple [str , str ], Module ]) -> List [Tuple [str , str ]]:
198+
199+ def resolve_modules_from_inputs (
200+ inputs : List [str ],
201+ modules_by_path : Dict [str , Module ],
202+ modules_by_key : Dict [Tuple [str , str ], Module ],
203+ ) -> List [Tuple [str , str ]]:
184204 resolved = set ()
185205 for item in inputs :
186206 # Check if item is a path
187207 abs_item = os .path .abspath (item )
188-
208+
189209 # If it's a file, try to find the nearest pom.xml
190- if os .path .isfile (abs_item ) or (not item .endswith ('pom.xml' ) and os .path .isdir (abs_item )):
191- # Heuristic: if it's a file, find containing pom
192- # if it's a dir, look for pom.xml inside or check if it matches a module path
193- candidate_path = abs_item
194- if os .path .isfile (candidate_path ) and not candidate_path .endswith ('pom.xml' ):
195- candidate_path = os .path .dirname (candidate_path )
196-
197- # Traverse up to find pom.xml
198- while candidate_path .startswith (os .getcwd ()) and len (candidate_path ) >= len (os .getcwd ()):
199- pom_path = os .path .join (candidate_path , 'pom.xml' )
200- if pom_path in modules_by_path :
201- resolved .add (modules_by_path [pom_path ].key )
202- break
203- candidate_path = os .path .dirname (candidate_path )
204- elif item .endswith ('pom.xml' ) and os .path .abspath (item ) in modules_by_path :
205- resolved .add (modules_by_path [os .path .abspath (item )].key )
210+ if os .path .isfile (abs_item ) or (
211+ not item .endswith ("pom.xml" ) and os .path .isdir (abs_item )
212+ ):
213+ # Heuristic: if it's a file, find containing pom
214+ # if it's a dir, look for pom.xml inside or check if it matches a module path
215+ candidate_path = abs_item
216+ if os .path .isfile (candidate_path ) and not candidate_path .endswith (
217+ "pom.xml"
218+ ):
219+ candidate_path = os .path .dirname (candidate_path )
220+
221+ # Traverse up to find pom.xml
222+ while candidate_path .startswith (os .getcwd ()) and len (candidate_path ) >= len (
223+ os .getcwd ()
224+ ):
225+ pom_path = os .path .join (candidate_path , "pom.xml" )
226+ if pom_path in modules_by_path :
227+ resolved .add (modules_by_path [pom_path ].key )
228+ break
229+ candidate_path = os .path .dirname (candidate_path )
230+ elif item .endswith ("pom.xml" ) and os .path .abspath (item ) in modules_by_path :
231+ resolved .add (modules_by_path [os .path .abspath (item )].key )
206232 else :
207233 # Try to match simple name (artifactId) or groupId:artifactId
208234 found = False
209235 for key , module in modules_by_key .items ():
210- if item == module .artifactId or item == f"{ module .group_id } :{ module .artifactId } " :
236+ if (
237+ item == module .artifactId
238+ or item == f"{ module .group_id } :{ module .artifactId } "
239+ ):
211240 resolved .add (key )
212241 found = True
213242 break
214243 if not found :
215- print (f"Warning: Could not resolve input '{ item } ' to a module." , file = sys .stderr )
216-
244+ print (
245+ f"Warning: Could not resolve input '{ item } ' to a module." ,
246+ file = sys .stderr ,
247+ )
248+
217249 return list (resolved )
218250
251+
219252def main ():
220- parser = argparse .ArgumentParser (description = 'Identify upstream dependencies for partial builds.' )
221- parser .add_argument ('modules' , nargs = '+' , help = 'List of modified modules or file paths' )
253+ parser = argparse .ArgumentParser (
254+ description = "Identify upstream dependencies for partial builds."
255+ )
256+ parser .add_argument (
257+ "modules" , nargs = "+" , help = "List of modified modules or file paths"
258+ )
222259 args = parser .parse_args ()
223260
224261 root_dir = os .getcwd ()
225262 modules_by_key , graph = build_graph (root_dir )
226263 modules_by_path = {m .path : m for m in modules_by_key .values ()}
227264
228- start_nodes = resolve_modules_from_inputs (args .modules , modules_by_path , modules_by_key )
229-
265+ start_nodes = resolve_modules_from_inputs (
266+ args .modules , modules_by_path , modules_by_key
267+ )
268+
230269 if not start_nodes :
231270 print ("No valid modules found from input." , file = sys .stderr )
232271 return
233272
234273 # Get transitive upstream dependencies
235- # We include the start nodes themselves in the output set if they are dependencies of other start nodes?
274+ # We include the start nodes themselves in the output set if they are dependencies of other start nodes?
236275 # Usually we want: Dependencies of (Start Nodes) NOT INCLUDING Start Nodes themselves, unless A depends on B and both are modified.
237276 # But for "installing dependencies", we generally want EVERYTHING upstream of the modified set.
238277 # If I modified A, and A depends on B, I want to install B.
239278 # If I modified A and B, and A depends on B, I want to install B (before A).
240279 # But usually the build system will build A and B if I say "build A and B".
241280 # The request is: "determine which modules will need to be compiled and installed to the local maven repository"
242281 # This implies we want the COMPLEMENT set of the modified modules, restricted to the upstream graph.
243-
282+
244283 all_dependencies = get_transitive_dependencies (start_nodes , graph )
245-
246- # Filter out the start nodes themselves, because they are the "modified" ones
247- # that will presumably be built by the test command itself?
248- # Actually, usually 'dependencies.sh' installs everything needed for the tests.
249- # If I am testing A, I need B installed.
250- # If I am testing A and B, I need C (upstream of B) installed.
251- # So I need (TransitiveClosure(StartNodes) - StartNodes).
252-
284+
253285 upstream_only = all_dependencies - set (start_nodes )
254-
255- # Map back to paths or artifactIds?
256- # `mvn -pl` takes directory paths or [groupId]:artifactId
257- # Directory paths are safer if artifactIds are not unique (rare but possible)
258- # relpath is good.
259-
286+
287+ # Topological sort for installation order
288+ # (Install dependencies before dependents)
289+ sorted_upstream = []
290+ visited_sort = set ()
291+
292+ def visit (node ):
293+ if node in visited_sort :
294+ return
295+ visited_sort .add (node )
296+ # Visit dependencies first
297+ if node in graph :
298+ for dep in graph [node ]:
299+ if dep in upstream_only :
300+ visit (dep )
301+
302+ sorted_upstream .append (node )
303+
304+ for node in upstream_only :
305+ visit (node )
306+
260307 results = []
261- for key in upstream_only :
308+ for key in sorted_upstream :
262309 module = modules_by_key [key ]
263310 rel_path = os .path .relpath (os .path .dirname (module .path ), root_dir )
264- # Maven -pl expects project dir or group:artifact
265311 results .append (rel_path )
266-
267- print (',' .join (sorted (results )))
268312
269- if __name__ == '__main__' :
313+ print ("," .join (results ))
314+
315+
316+ if __name__ == "__main__" :
270317 main ()
0 commit comments