1616"""
1717
1818import json
19+ import os
1920import re
2021import sys
2122from collections import defaultdict
2223from datetime import datetime
24+ from pathlib import Path
2325
2426try :
2527 from jinja2 import Environment , FileSystemLoader
@@ -34,12 +36,13 @@ def parse_arguments():
3436 """Parse command-line arguments."""
3537 if len (sys .argv ) < 7 :
3638 print (
37- "Usage: analyze_build_trace.py <trace_file > <output_file> <target> <granularity> <build_time> <template_dir>"
39+ "Usage: analyze_build_trace.py <trace_files_or_dir > <output_file> <target> <granularity> <build_time> <template_dir>"
3840 )
41+ print (" trace_files_or_dir: Comma-separated list of trace files OR directory containing .json files" )
3942 sys .exit (1 )
4043
4144 return {
42- "trace_file " : sys .argv [1 ],
45+ "trace_input " : sys .argv [1 ],
4346 "output_file" : sys .argv [2 ],
4447 "target" : sys .argv [3 ],
4548 "granularity" : sys .argv [4 ],
@@ -48,53 +51,126 @@ def parse_arguments():
4851 }
4952
5053
51- def load_trace_data (trace_file ):
52- """Load and parse the trace JSON file."""
53- print (f"Loading trace file: { trace_file } " )
54- with open (trace_file , "r" ) as f :
55- return json .load (f )
54+ def find_trace_files (trace_input ):
55+ """Find all trace files from input (file list, single file, or directory)."""
56+ trace_files = []
57+
58+ # Check if it's a directory
59+ if os .path .isdir (trace_input ):
60+ print (f"Scanning directory: { trace_input } " )
61+ for root , dirs , files in os .walk (trace_input ):
62+ for file in files :
63+ # Include .cpp.json and .hip.json, exclude compile_commands.json and CMake files
64+ if file .endswith (('.cpp.json' , '.hip.json' )) and 'CMakeFiles' in root :
65+ trace_files .append (os .path .join (root , file ))
66+ trace_files .sort ()
67+ # Check if it's a comma-separated list
68+ elif ',' in trace_input :
69+ trace_files = [f .strip () for f in trace_input .split (',' )]
70+ # Single file
71+ else :
72+ trace_files = [trace_input ]
73+
74+ # Filter out non-existent files
75+ valid_files = [f for f in trace_files if os .path .isfile (f )]
76+
77+ if not valid_files :
78+ print (f"Error: No valid trace files found in: { trace_input } " , file = sys .stderr )
79+ sys .exit (1 )
80+
81+ print (f"Found { len (valid_files )} trace file(s)" )
82+ return valid_files
83+
84+
85+ def load_trace_data (trace_files ):
86+ """Load and parse multiple trace JSON files."""
87+ all_data = []
88+
89+ for trace_file in trace_files :
90+ print (f" Loading: { trace_file } " )
91+ try :
92+ with open (trace_file , "r" ) as f :
93+ data = json .load (f )
94+ # Get file basename for tracking
95+ file_name = os .path .basename (trace_file )
96+ all_data .append ({
97+ 'file' : file_name ,
98+ 'path' : trace_file ,
99+ 'data' : data
100+ })
101+ except Exception as e :
102+ print (f" Warning: Failed to load { trace_file } : { e } " , file = sys .stderr )
103+
104+ return all_data
56105
57106
58- def process_events (data ):
59- """Process trace events and extract template instantiation statistics."""
60- print ("Processing events..." )
107+ def process_events (all_trace_data ):
108+ """Process trace events from multiple files and extract statistics."""
109+ print ("Processing events from all files ..." )
61110
62111 template_stats = defaultdict (lambda : {"count" : 0 , "total_dur" : 0 })
63112 phase_stats = defaultdict (int )
64113 top_individual = []
114+ file_stats = []
115+ total_events = 0
116+
117+ for trace_info in all_trace_data :
118+ file_name = trace_info ['file' ]
119+ data = trace_info ['data' ]
120+ events = data .get ("traceEvents" , [])
121+
122+ file_template_time = 0
123+ file_event_count = len (events )
124+ total_events += file_event_count
125+
126+ print (f" Processing { file_name } : { file_event_count :,} events" )
127+
128+ for event in events :
129+ name = event .get ("name" , "" )
130+ dur = int (event .get ("dur" , 0 )) # Keep as integer microseconds
131+
132+ if name and dur > 0 :
133+ phase_stats [name ] += dur
65134
66- for event in data .get ("traceEvents" , []):
67- name = event .get ("name" , "" )
68- dur = int (event .get ("dur" , 0 )) # Keep as integer microseconds
135+ if name in ["InstantiateFunction" , "InstantiateClass" ]:
136+ detail = event .get ("args" , {}).get ("detail" , "" )
137+ top_individual .append ({
138+ "detail" : detail ,
139+ "dur" : dur ,
140+ "type" : name ,
141+ "file" : file_name
142+ })
69143
70- if name and dur > 0 :
71- phase_stats [name ] += dur
144+ file_template_time += dur
72145
73- if name in ["InstantiateFunction" , "InstantiateClass" ]:
74- detail = event .get ("args" , {}).get ("detail" , "" )
75- top_individual .append ({"detail" : detail , "dur" : dur , "type" : name })
146+ # Extract template name (everything before '<' or '(')
147+ match = re .match (r"^([^<(]+)" , detail )
148+ if match :
149+ template_name = match .group (1 ).strip ()
150+ # Normalize template names
151+ template_name = re .sub (r"^ck::" , "" , template_name )
152+ template_name = re .sub (r"^std::" , "std::" , template_name )
76153
77- # Extract template name (everything before '<' or '(')
78- match = re .match (r"^([^<(]+)" , detail )
79- if match :
80- template_name = match .group (1 ).strip ()
81- # Normalize template names
82- template_name = re .sub (r"^ck::" , "" , template_name )
83- template_name = re .sub (r"^std::" , "std::" , template_name )
154+ template_stats [template_name ]["count" ] += 1
155+ template_stats [template_name ]["total_dur" ] += dur
84156
85- template_stats [template_name ]["count" ] += 1
86- template_stats [template_name ]["total_dur" ] += dur
157+ file_stats .append ({
158+ 'name' : file_name ,
159+ 'events' : file_event_count ,
160+ 'template_time' : file_template_time
161+ })
87162
88- return template_stats , phase_stats , top_individual
163+ return template_stats , phase_stats , top_individual , file_stats , total_events
89164
90165
91- def prepare_template_data (template_stats , phase_stats , top_individual ):
166+ def prepare_template_data (template_stats , phase_stats , top_individual , file_stats ):
92167 """Prepare and calculate derived statistics for template rendering."""
93168 print ("Sorting data..." )
94169
95170 # Sort data
96171 sorted_phases = sorted (phase_stats .items (), key = lambda x : x [1 ], reverse = True )
97172 top_individual .sort (key = lambda x : x ["dur" ], reverse = True )
173+ file_stats .sort (key = lambda x : x ["template_time" ], reverse = True )
98174
99175 # Calculate totals
100176 total_template_time = sum (s ["total_dur" ] for s in template_stats .values ())
@@ -170,6 +246,7 @@ def prepare_template_data(template_stats, phase_stats, top_individual):
170246 "median_count" : median_count ,
171247 "top10_pct" : top10_pct ,
172248 "unique_families" : len (template_stats ),
249+ "file_stats" : file_stats ,
173250 }
174251
175252
@@ -208,7 +285,7 @@ def us_to_s(value):
208285 return env
209286
210287
211- def generate_report (env , data , args , total_events ):
288+ def generate_report (env , data , args , total_events , num_files ):
212289 """Generate the final report using Jinja2 template."""
213290 print ("Rendering report with Jinja2..." )
214291
@@ -220,6 +297,7 @@ def generate_report(env, data, args, total_events):
220297 granularity = args ["granularity" ],
221298 build_time = args ["build_time" ],
222299 total_events = total_events ,
300+ num_files = num_files ,
223301 total_instantiations = data ["total_inst" ],
224302 unique_families = data ["unique_families" ],
225303 total_trace_time = data ["total_trace_time" ],
@@ -230,6 +308,7 @@ def generate_report(env, data, args, total_events):
230308 templates_by_count = data ["templates_by_count" ],
231309 median_count = data ["median_count" ],
232310 top10_pct = data ["top10_pct" ],
311+ file_stats = data ["file_stats" ],
233312 )
234313
235314 return report_content
@@ -239,28 +318,29 @@ def main():
239318 """Main entry point for the analysis tool."""
240319 args = parse_arguments ()
241320
242- # Load trace data
243- trace_data = load_trace_data (args ["trace_file " ])
244- total_events = len ( trace_data . get ( "traceEvents" , []) )
321+ # Find and load trace files
322+ trace_files = find_trace_files (args ["trace_input " ])
323+ all_trace_data = load_trace_data ( trace_files )
245324
246- # Process events
247- template_stats , phase_stats , top_individual = process_events (trace_data )
325+ # Process events from all files
326+ template_stats , phase_stats , top_individual , file_stats , total_events = process_events (all_trace_data )
248327
249328 # Prepare template data
250- data = prepare_template_data (template_stats , phase_stats , top_individual )
329+ data = prepare_template_data (template_stats , phase_stats , top_individual , file_stats )
251330
252331 # Setup Jinja2 environment
253332 env = setup_jinja_environment (args ["template_dir" ])
254333
255334 # Generate report
256- report_content = generate_report (env , data , args , total_events )
335+ report_content = generate_report (env , data , args , total_events , len ( all_trace_data ) )
257336
258337 # Write output
259338 with open (args ["output_file" ], "w" ) as f :
260339 f .write (report_content )
261340
262341 print (f"Report generated: { args ['output_file' ]} " )
263- print (f"Report size: { len (report_content )} bytes" )
342+ print (f"Report size: { len (report_content ):,} bytes" )
343+ print (f"Analyzed { len (all_trace_data )} file(s) with { total_events :,} total events" )
264344
265345
266346if __name__ == "__main__" :
0 commit comments