99logger = logging .getLogger (__name__ )
1010
1111
12- def find_device_by_uuid (devices , hip_uuid ):
12+ def _find_device_by_uuid (devices , hip_uuid ):
1313 result = None
1414
1515 # Missing input
@@ -20,7 +20,7 @@ def find_device_by_uuid(devices, hip_uuid):
2020 try :
2121 hip_hex = UUID (hex = hip_uuid ).bytes .decode ("ascii" )
2222 except (UnicodeDecodeError , ValueError ):
23- return None
23+ hip_hex = str ( hip_uuid )
2424
2525 for index , device in enumerate (devices ):
2626 smi_uuid = str (amdsmi .amdsmi_get_gpu_device_uuid (device ))
@@ -39,7 +39,7 @@ def find_device_by_uuid(devices, hip_uuid):
3939 return result
4040
4141
42- def find_device_by_bdf (devices , pci_domain , pci_bus , pci_device ):
42+ def _find_device_by_bdf (devices , pci_domain , pci_bus , pci_device ):
4343 result = None
4444
4545 # Missing input
@@ -65,7 +65,13 @@ def find_device_by_bdf(devices, pci_domain, pci_bus, pci_device):
6565 return result
6666
6767
68- SUPPORTED_OBSERVABLES = ["energy" , "core_freq" , "mem_freq" , "temperature" , "core_voltage" ]
68+ SUPPORTED_OBSERVABLES = [
69+ "energy" ,
70+ "core_freq" ,
71+ "mem_freq" ,
72+ "temperature" ,
73+ "core_voltage" ,
74+ ]
6975
7076
7177class AMDSMIObserver (BenchmarkObserver ):
@@ -103,13 +109,13 @@ def register_device(self, dev):
103109
104110 # Try to find by UUID
105111 uuid = env .get ("uuid" )
106- uuid_idx = find_device_by_uuid (devices , uuid )
112+ uuid_idx = _find_device_by_uuid (devices , uuid )
107113
108114 # Try to find by PCI information
109115 pci_domain = env .get ("pci_domain_id" )
110116 pci_bus = env .get ("pci_bus_id" )
111117 pci_device = env .get ("pci_device_id" )
112- pci_idx = find_device_by_bdf (devices , pci_domain , pci_bus , pci_device )
118+ pci_idx = _find_device_by_bdf (devices , pci_domain , pci_bus , pci_device )
113119
114120 bdf = f"domain { pci_domain } , bus { pci_bus } , device { pci_device } "
115121
@@ -132,13 +138,13 @@ def register_device(self, dev):
132138 logger .info (f"selected AMDSMI device { self .device_id } " )
133139
134140 # Warn if UUID wants a different device
135- if self .device_id != uuid_idx :
141+ if uuid_idx is not None and self .device_id != uuid_idx :
136142 logger .warning (
137143 f"specified device has mismatching UUID ({ uuid } ): { uuid_idx } != { self .device_id } "
138144 )
139145
140146 # Warn if PCI wants a different device
141- if self .device_id != pci_idx :
147+ if pci_idx is not None and self .device_id != pci_idx :
142148 logger .warning (
143149 f"specified device has mismatching PCI ({ bdf } ): { pci_idx } != { self .device_id } "
144150 )
@@ -162,9 +168,13 @@ def during(self):
162168
163169 if "core_voltage" in self .observables :
164170 milli_volt = amdsmi .amdsmi_get_gpu_volt_metric (
165- self .device , amdsmi .AmdSmiVoltageType .VDDGFX , amdsmi .AmdSmiVoltageMetric .CURRENT
171+ self .device ,
172+ amdsmi .AmdSmiVoltageType .VDDGFX ,
173+ amdsmi .AmdSmiVoltageMetric .CURRENT ,
166174 )
167- self .during_results ["core_voltage" ].append (milli_volt * 1e-3 ) # milli -> volt
175+
176+ # milli * 1e-3 -> volt
177+ self .during_results ["core_voltage" ].append (milli_volt * 1e-3 )
168178
169179 if "core_freq" in self .observables :
170180 obj = amdsmi .amdsmi_get_clk_freq (self .device , amdsmi .AmdSmiClkType .GFX )
@@ -188,7 +198,7 @@ def during(self):
188198 def after_finish (self ):
189199 self .during ()
190200
191- # Energy is special as it does not need integration over time
201+ # Energy is an exception as it does not need integration over time
192202 if "energy" in self .observables :
193203 before = self .energy_after_start
194204 after = amdsmi .amdsmi_get_energy_count (self .device )
@@ -204,20 +214,26 @@ def after_finish(self):
204214 diff = np .uint64 (after [energy_field ]) - np .uint64 (before [energy_field ])
205215 resolution = before ["counter_resolution" ]
206216 energy_mj = float (diff ) * float (resolution )
207- self .iteration_results ["energy" ].append (energy_mj * 1e-6 ) # microJ -> J
217+
218+ # microJ * 1e-6 -> J
219+ self .iteration_results ["energy" ].append (energy_mj * 1e-6 )
208220
209221 # For the others, we integrate over time and take the average
222+ x = self .during_timestamps
210223 for key , values in self .during_results .items ():
211- x = self . during_timestamps
212- avg = np .trapezoid (values , x ) / np .ptp (x ) # np.trapz in older versions of np
224+ # np.trapezoid was np.trapz in older versions of np
225+ avg = np .trapezoid (values , x ) / np .ptp (x )
213226 self .iteration_results [key ].append (avg )
214227
215228 def get_results (self ):
216229 results = dict ()
217230
218231 for key in list (self .iteration_results ):
219- avg = np .average (self .iteration_results [key ]) # Average of results at each iteration
220- self .iteration_results [key ] = [] # Reset to empty
232+ # Average of results at each iteration
233+ avg = np .average (self .iteration_results [key ])
234+
235+ # Reset to empty
236+ self .iteration_results [key ] = []
221237
222238 if self .prefix :
223239 results [f"{ self .prefix } _{ key } " ] = avg
0 commit comments