Skip to content

Commit 65a4e4b

Browse files
Merge pull request #372 from KernelTuner/amdsmi-observer
Add `AMDSMIObserver` that uses `amdsmi` to measure energy
2 parents 2ea380b + c5fd7c4 commit 65a4e4b

1 file changed

Lines changed: 243 additions & 0 deletions

File tree

kernel_tuner/observers/amd.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import amdsmi
2+
import logging
3+
import numpy as np
4+
import time
5+
6+
from uuid import UUID
7+
from kernel_tuner.observers import BenchmarkObserver
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def _find_device_by_uuid(devices, hip_uuid):
13+
result = None
14+
15+
# Missing input
16+
if hip_uuid is None:
17+
return None
18+
19+
# HIP UUID has a strange encoding: https://github.com/ROCm/ROCm/issues/1642
20+
try:
21+
hip_hex = UUID(hex=hip_uuid).bytes.decode("ascii")
22+
except (UnicodeDecodeError, ValueError):
23+
hip_hex = str(hip_uuid)
24+
25+
for index, device in enumerate(devices):
26+
smi_uuid = str(amdsmi.amdsmi_get_gpu_device_uuid(device))
27+
28+
# For some reason, only the last 12 bytes match
29+
if hip_hex[-12:] != smi_uuid[-12:]:
30+
continue
31+
32+
# Multiple devices have same UUID?
33+
if result is not None:
34+
logger.warning(f"could not find device: multiple have UUID {hip_uuid}")
35+
return None
36+
37+
result = index
38+
39+
return result
40+
41+
42+
def _find_device_by_bdf(devices, pci_domain, pci_bus, pci_device):
43+
result = None
44+
45+
# Missing input
46+
if pci_domain is None or pci_bus is None or pci_device is None:
47+
return None
48+
49+
for index, device in enumerate(devices):
50+
bdf = amdsmi.amdsmi_get_gpu_bdf_id(device)
51+
x = (bdf >> 32) & 0xFFFFFFFF
52+
y = (bdf >> 8) & 0xFF
53+
z = (bdf >> 3) & 0x1F
54+
55+
if (x, y, z) != (pci_domain, pci_bus, pci_device):
56+
continue
57+
58+
if result is not None:
59+
msg = f"domain {pci_domain}, bus {pci_bus}, device {pci_device}"
60+
logger.warning(f"could not find device: multiple have PCI {msg}")
61+
return None
62+
63+
result = index
64+
65+
return result
66+
67+
68+
SUPPORTED_OBSERVABLES = [
69+
"energy",
70+
"core_freq",
71+
"mem_freq",
72+
"temperature",
73+
"core_voltage",
74+
]
75+
76+
77+
class AMDSMIObserver(BenchmarkObserver):
78+
"""
79+
BenchmarkObserver that uses amdsmi to monitor AMD GPUs and measure energy usage (`energy`),
80+
core clock frequency (`core_freq`), memory clock frequency (`mem_freq`), temperature (`temperature`),
81+
and core voltage (`core_voltage`).
82+
"""
83+
84+
def __init__(self, observables=["energy"], *, device_id=None, prefix="amdsmi"):
85+
"""
86+
Initialize the AMDSMIObserver.
87+
88+
Supported observables are: `energy`, `core_freq`, `mem_freq`, `temperature`, and `core_voltage`.
89+
90+
:param observables: List of metrics to monitor. Defaults to just energy.
91+
:param device_id: Specific AMD device index. If None, auto-detection is used.
92+
:param prefix: Prefix used for name in the metrics. Defaults to "amdsmi".
93+
"""
94+
for obs in observables:
95+
if obs not in SUPPORTED_OBSERVABLES:
96+
raise ValueError(f"Observable {obs} not supported: {SUPPORTED_OBSERVABLES}")
97+
98+
self.observables = set(observables)
99+
self.iteration_results = {k: [] for k in self.observables}
100+
self.prefix = prefix
101+
self.device_id = device_id
102+
self.device = None
103+
104+
def register_device(self, dev):
105+
amdsmi.amdsmi_init()
106+
devices = amdsmi.amdsmi_get_processor_handles()
107+
108+
env = getattr(dev, "env", dict())
109+
110+
# Try to find by UUID
111+
uuid = env.get("uuid")
112+
uuid_idx = _find_device_by_uuid(devices, uuid)
113+
114+
# Try to find by PCI information
115+
pci_domain = env.get("pci_domain_id")
116+
pci_bus = env.get("pci_bus_id")
117+
pci_device = env.get("pci_device_id")
118+
pci_idx = _find_device_by_bdf(devices, pci_domain, pci_bus, pci_device)
119+
120+
bdf = f"domain {pci_domain}, bus {pci_bus}, device {pci_device}"
121+
122+
# If no device id is specified by user, get it from the UUID and PCI
123+
if self.device_id is None:
124+
if uuid_idx is None:
125+
raise ValueError(f"failed to detect AMD device: invalid UUID of backend: {uuid}")
126+
127+
if pci_idx is None:
128+
raise ValueError(
129+
f"failed to detect AMD device: invalid PCI information of backend: {bdf}"
130+
)
131+
132+
if uuid_idx != pci_idx:
133+
raise ValueError(
134+
"failed to detect AMD device: UUID and PCI information are inconsistent"
135+
)
136+
137+
self.device_id = uuid_idx
138+
logger.info(f"selected AMDSMI device {self.device_id}")
139+
140+
# Warn if UUID wants a different device
141+
if uuid_idx is not None and self.device_id != uuid_idx:
142+
logger.warning(
143+
f"specified device has mismatching UUID ({uuid}): {uuid_idx} != {self.device_id}"
144+
)
145+
146+
# Warn if PCI wants a different device
147+
if pci_idx is not None and self.device_id != pci_idx:
148+
logger.warning(
149+
f"specified device has mismatching PCI ({bdf}): {pci_idx} != {self.device_id}"
150+
)
151+
152+
if not (0 <= self.device_id < len(devices)):
153+
raise ValueError(
154+
f"invalid AMD SMI device_id {self.device_id}, found {len(devices)} devices"
155+
)
156+
157+
self.device = devices[self.device_id]
158+
159+
def after_start(self):
160+
self.energy_after_start = amdsmi.amdsmi_get_energy_count(self.device)
161+
self.during_timestamps = []
162+
self.during_results = {k: [] for k in self.observables if k != "energy"}
163+
self.during()
164+
165+
def during(self):
166+
# Get the current timestamp for measurements
167+
self.during_timestamps.append(time.perf_counter())
168+
169+
if "core_voltage" in self.observables:
170+
milli_volt = amdsmi.amdsmi_get_gpu_volt_metric(
171+
self.device,
172+
amdsmi.AmdSmiVoltageType.VDDGFX,
173+
amdsmi.AmdSmiVoltageMetric.CURRENT,
174+
)
175+
176+
# milli * 1-e3 -> volt
177+
self.during_results["core_voltage"].append(milli_volt * 1e-3)
178+
179+
if "core_freq" in self.observables:
180+
obj = amdsmi.amdsmi_get_clk_freq(self.device, amdsmi.AmdSmiClkType.GFX)
181+
freq = obj["frequency"][obj["current"]]
182+
self.during_results["core_freq"].append(freq)
183+
184+
if "mem_freq" in self.observables:
185+
obj = amdsmi.amdsmi_get_clk_freq(self.device, amdsmi.AmdSmiClkType.MEM)
186+
freq = obj["frequency"][obj["current"]]
187+
self.during_results["mem_freq"].append(freq)
188+
189+
if "temperature" in self.observables:
190+
temp = amdsmi.amdsmi_get_temp_metric(
191+
self.device,
192+
amdsmi.AmdSmiTemperatureType.HOTSPOT,
193+
amdsmi.AmdSmiTemperatureMetric.CURRENT,
194+
)
195+
196+
self.during_results["temperature"].append(temp)
197+
198+
def after_finish(self):
199+
self.during()
200+
201+
# Energy is an exception as it does not need integration over time
202+
if "energy" in self.observables:
203+
before = self.energy_after_start
204+
after = amdsmi.amdsmi_get_energy_count(self.device)
205+
206+
# This field changed names in rocm 6.4
207+
if "energy_accumulator" in before:
208+
energy_field = "energy_accumulator"
209+
elif "power" in before:
210+
energy_field = "power"
211+
else:
212+
raise RuntimeError(f"invalid result from amdsmi_get_energy_count: {before}")
213+
214+
diff = np.uint64(after[energy_field]) - np.uint64(before[energy_field])
215+
resolution = before["counter_resolution"]
216+
energy_mj = float(diff) * float(resolution)
217+
218+
# microJ * 1e-6 -> J
219+
self.iteration_results["energy"].append(energy_mj * 1e-6)
220+
221+
# For the others, we integrate over time and take the average
222+
x = self.during_timestamps
223+
for key, values in self.during_results.items():
224+
# np.trapezoid was np.trapz in older versions of np
225+
avg = np.trapezoid(values, x) / np.ptp(x)
226+
self.iteration_results[key].append(avg)
227+
228+
def get_results(self):
229+
results = dict()
230+
231+
for key in list(self.iteration_results):
232+
# Average of results at each iteration
233+
avg = np.average(self.iteration_results[key])
234+
235+
# Reset to empty
236+
self.iteration_results[key] = []
237+
238+
if self.prefix:
239+
results[f"{self.prefix}_{key}"] = avg
240+
else:
241+
results[key] = avg
242+
243+
return results

0 commit comments

Comments
 (0)