Skip to content

Commit 22efdef

Browse files
committed
Add AMDSMIContinuousObserver
1 parent 8756e6d commit 22efdef

1 file changed

Lines changed: 236 additions & 78 deletions

File tree

kernel_tuner/observers/amd.py

Lines changed: 236 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33
import numpy as np
44
import time
55

6+
# Trapz was renamed to trapezoid in Numpy 2.0
7+
try:
8+
from numpy import trapezoid
9+
except ImportError:
10+
from numpy import trapz as trapezoid
11+
612
from uuid import UUID
7-
from kernel_tuner.observers import BenchmarkObserver
13+
from kernel_tuner.observers import BenchmarkObserver, ContinuousObserver
814

915
logger = logging.getLogger(__name__)
1016

@@ -65,23 +71,177 @@ def _find_device_by_bdf(devices, pci_domain, pci_bus, pci_device):
6571
return result
6672

6773

74+
class AMDDevice:
    """Thin wrapper around one amdsmi device handle.

    Every public method issues a single amdsmi query against the wrapped
    handle and returns the value of interest.
    """

    def __init__(self, device):
        """Store the amdsmi device handle that all queries are issued against."""
        self.device = device

    @staticmethod
    def _nan_if_unavailable(value):
        """Map a string placeholder (amdsmi reports "N/A" on error) to NaN."""
        return float("nan") if isinstance(value, str) else value

    @staticmethod
    def _select_socket_power(info):
        """Pick the socket power reading out of an `amdsmi_get_power_info` result.

        `current_socket_power` is preferred over `average_socket_power`. A key
        whose value is a string placeholder (e.g. "N/A") is treated as
        unavailable, so the next candidate is tried instead of returning the
        placeholder. Returns NaN when keys are present but none holds a usable
        value.

        Raises:
            RuntimeError: if neither key is present in `info`.
        """
        found_key = False

        # Order matters: newer (Mi300+) cards report the current power,
        # older cards only report the average power.
        for key in ("current_socket_power", "average_socket_power"):
            if key not in info:
                continue

            found_key = True
            value = info[key]

            if not isinstance(value, str):
                return value

        if not found_key:
            raise RuntimeError(f"invalid result from amdsmi_get_power_info: {info}")

        return float("nan")

    def _current_clock(self, clk_type):
        """Return the currently selected frequency for the given clock type."""
        obj = amdsmi.amdsmi_get_clk_freq(self.device, clk_type)
        # `current` is the index of the active entry in the `frequency` list
        return obj["frequency"][obj["current"]]

    def _activity(self, field):
        """Return one field of `amdsmi_get_gpu_activity`, or NaN if unavailable."""
        obj = amdsmi.amdsmi_get_gpu_activity(self.device)
        # Result is "N/A" on error, return NaN instead
        return self._nan_if_unavailable(obj[field])

    def total_energy_usage(self):
        """Return the raw energy counter reading from `amdsmi_get_energy_count`.

        The result is the dictionary returned by amdsmi, normalised so that
        the accumulated counter is always available as `energy_accumulator`.

        Raises:
            RuntimeError: if the result has neither `energy_accumulator`
                nor `power`.
        """
        result = amdsmi.amdsmi_get_energy_count(self.device)

        # This field changed name in rocm 6.4
        if "energy_accumulator" not in result:
            if "power" in result:
                result["energy_accumulator"] = result["power"]
            else:
                raise RuntimeError(f"invalid result from amdsmi_get_energy_count: {result}")

        return result

    def current_power_usage(self):
        """Return the socket power reported by `amdsmi_get_power_info`.

        NOTE(review): the unit is presumably Watt -- confirm against the
        amdsmi documentation.

        Raises:
            RuntimeError: if the result contains no socket power field.
        """
        info = amdsmi.amdsmi_get_power_info(self.device)
        return self._select_socket_power(info)

    def core_voltage(self):
        """Returns current core (VDDGFX) voltage in Volt."""
        milli_volt = amdsmi.amdsmi_get_gpu_volt_metric(
            self.device,
            amdsmi.AmdSmiVoltageType.VDDGFX,
            amdsmi.AmdSmiVoltageMetric.CURRENT,
        )

        # millivolt * 1e-3 -> volt
        return milli_volt * 1e-3

    def temperature(self):
        """Returns current hotspot temperature in Celsius."""
        return amdsmi.amdsmi_get_temp_metric(
            self.device,
            amdsmi.AmdSmiTemperatureType.HOTSPOT,
            amdsmi.AmdSmiTemperatureMetric.CURRENT,
        )

    def mem_temperature(self):
        """Returns current memory (VRAM) temperature in Celsius."""
        return amdsmi.amdsmi_get_temp_metric(
            self.device,
            amdsmi.AmdSmiTemperatureType.VRAM,
            amdsmi.AmdSmiTemperatureMetric.CURRENT,
        )

    def core_freq(self):
        """Returns current core clock frequency in Hz."""
        return self._current_clock(amdsmi.AmdSmiClkType.GFX)

    def mem_freq(self):
        """Returns current memory clock frequency in Hz."""
        return self._current_clock(amdsmi.AmdSmiClkType.MEM)

    def core_activity(self):
        """Returns core usage (`gfx_activity`) as percentage (0-100), NaN on error."""
        return self._activity("gfx_activity")

    def mem_activity(self):
        """Returns memory usage (`umc_activity`) as percentage (0-100), NaN on error."""
        return self._activity("umc_activity")
157+
158+
68159
SUPPORTED_OBSERVABLES = [
69160
"energy",
161+
"power",
70162
"core_freq",
71163
"mem_freq",
72164
"temperature",
165+
"mem_temperature",
73166
"core_voltage",
167+
"core_activity",
168+
"mem_activity",
74169
]
75170

76171

172+
class AMDSMIContinuousObserver(ContinuousObserver):
    """Continuous-benchmarking front end for an `AMDSMIObserver`.

    All measurements are delegated to `parent`. This class delays the start
    of the parent's measurement until a short warmup period has passed, and
    rescales the measured energy from "whole continuous run" to "per kernel".
    """

    def __init__(self, parent, continuous_duration=1.0):
        """
        :param parent: Observer that performs the actual measurements. It must
            provide the observer hooks and `field_name`.
        :param continuous_duration: Duration of the continuous benchmark,
            presumably in seconds -- TODO confirm against Kernel Tuner's core.
        """
        self.parent = parent
        self.continuous_duration = continuous_duration
        self.warmup_time = min(0.1, continuous_duration / 2)

        # Measurement state. `after_start` resets it for every run, but it is
        # defined here so that the other hooks never hit a missing attribute
        # when they are called before the first `after_start`.
        self.warmup_completed = False
        self.start_time = None
        self.end_time = None

        # This is assigned by Kernel Tuner's core
        self.results = None

    def before_start(self):
        """Forward to the parent observer."""
        self.parent.before_start()

    def after_start(self):
        """Arm the warmup timer; the parent is only started once it expires."""
        self.warmup_completed = False
        self.end_time = None
        self.start_time = time.perf_counter() + self.warmup_time

    def during(self):
        """Start the parent once the warmup has passed, then sample through it."""
        now = time.perf_counter()

        if not self.warmup_completed:
            # Not armed yet, or still warming up
            if self.start_time is None or now < self.start_time:
                return

            # Only call `after_start` once warmup time has passed
            self.start_time = now
            self.warmup_completed = True
            self.parent.after_start()

        self.parent.during()

    def after_finish(self):
        """Stop the parent measurement, if it was ever started."""
        if self.warmup_completed:
            # Remember when the measurement stopped, so that the elapsed time
            # does not depend on how much later `get_results` is called.
            self.end_time = time.perf_counter()
            self.parent.after_finish()

    def get_results(self):
        """Return the parent's results with energy rescaled to a single kernel.

        Returns an empty dict when the run ended before the warmup completed,
        since the parent never measured anything in that case.
        """
        if not self.warmup_completed:
            return dict()

        # Fall back to "now" if `after_finish` has not recorded an end time
        end_time = self.end_time if self.end_time is not None else time.perf_counter()
        elapsed_sec = end_time - self.start_time

        # `time` is scaled by 1e-3 to get seconds (i.e. it is taken to be in ms)
        time_sec = self.results["time"] * 1e-3
        ratio = time_sec / elapsed_sec

        # Get results from the parent
        results = self.parent.get_results()

        # The energy field measures the energy over the entire
        # continuous duration. However, we want the average
        # energy usage _per_ kernel. To fix this, we multiply
        # by the ratio of time per kernel to elapsed time
        energy_field = self.parent.field_name("energy")

        if energy_field in results:
            results[energy_field] = results[energy_field] * ratio

        return results
227+
228+
77229
class AMDSMIObserver(BenchmarkObserver):
78230
"""
79231
BenchmarkObserver that uses amdsmi to monitor AMD GPUs and measure energy usage (`energy`),
80232
core clock frequency (`core_freq`), memory clock frequency (`mem_freq`), temperature (`temperature`),
81233
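core voltage (`core_voltage`), average power (`power`), memory temperature (`mem_temperature`),
and core/memory activity (`core_activity`, `mem_activity`).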
82234
"""
83235

84-
def __init__(self, observables=["energy"], *, device_id=None, prefix="amdsmi"):
236+
def __init__(
237+
self,
238+
observables=["energy"],
239+
*,
240+
device_id=None,
241+
prefix="amdsmi",
242+
use_continuous_observer=True,
243+
continuous_duration=1.0,
244+
):
85245
"""
86246
Initialize the AMDSMIObserver.
87247
@@ -96,10 +256,12 @@ def __init__(self, observables=["energy"], *, device_id=None, prefix="amdsmi"):
96256
raise ValueError(f"Observable {obs} not supported: {SUPPORTED_OBSERVABLES}")
97257

98258
self.observables = set(observables)
99-
self.iteration_results = {k: [] for k in self.observables}
100259
self.prefix = prefix
101260
self.device_id = device_id
102261
self.device = None
262+
self.use_continuous_observer = use_continuous_observer
263+
self.continuous_duration = continuous_duration
264+
self.results_per_iteration = {self.field_name(k): [] for k in self.observables}
103265

104266
def register_device(self, dev):
105267
amdsmi.amdsmi_init()
@@ -125,119 +287,115 @@ def register_device(self, dev):
125287
raise ValueError(f"failed to detect AMD device: invalid UUID of backend: {uuid}")
126288

127289
if pci_idx is None:
128-
raise ValueError(
129-
f"failed to detect AMD device: invalid PCI information of backend: {bdf}"
130-
)
290+
raise ValueError(f"failed to detect AMD device: invalid PCI information of backend: {bdf}")
131291

132292
if uuid_idx != pci_idx:
133-
raise ValueError(
134-
"failed to detect AMD device: UUID and PCI information are inconsistent"
135-
)
293+
raise ValueError("failed to detect AMD device: UUID and PCI information are inconsistent")
136294

137295
self.device_id = uuid_idx
138296
logger.info(f"selected AMDSMI device {self.device_id}")
139297

140298
# Warn if UUID wants a different device
141299
if uuid_idx is not None and self.device_id != uuid_idx:
142-
logger.warning(
143-
f"specified device has mismatching UUID ({uuid}): {uuid_idx} != {self.device_id}"
144-
)
300+
logger.warning(f"specified device has mismatching UUID ({uuid}): {uuid_idx} != {self.device_id}")
145301

146302
# Warn if PCI wants a different device
147303
if pci_idx is not None and self.device_id != pci_idx:
148-
logger.warning(
149-
f"specified device has mismatching PCI ({bdf}): {pci_idx} != {self.device_id}"
150-
)
304+
logger.warning(f"specified device has mismatching PCI ({bdf}): {pci_idx} != {self.device_id}")
151305

152306
if not (0 <= self.device_id < len(devices)):
153-
raise ValueError(
154-
f"invalid AMD SMI device_id {self.device_id}, found {len(devices)} devices"
155-
)
307+
raise ValueError(f"invalid AMD SMI device_id {self.device_id}, found {len(devices)} devices")
308+
309+
self.device = AMDDevice(devices[self.device_id])
156310

157-
self.device = devices[self.device_id]
311+
if self.use_continuous_observer:
312+
self.continuous_observer = AMDSMIContinuousObserver(self, continuous_duration=self.continuous_duration)
158313

159314
def after_start(self):
160-
self.energy_after_start = amdsmi.amdsmi_get_energy_count(self.device)
161-
self.during_timestamps = []
162-
self.during_results = {k: [] for k in self.observables if k != "energy"}
163-
self.during()
315+
self.energy_after_start = self.device.total_energy_usage()
316+
self.sample_timestamps = []
317+
self.sample_values = {k: [] for k in self.results_per_iteration}
318+
self.sample_metrics()
164319

165320
def during(self):
166-
# Get the current timestamp for measurements
167-
self.during_timestamps.append(time.perf_counter())
321+
self.sample_metrics()
168322

169-
if "core_voltage" in self.observables:
170-
milli_volt = amdsmi.amdsmi_get_gpu_volt_metric(
171-
self.device,
172-
amdsmi.AmdSmiVoltageType.VDDGFX,
173-
amdsmi.AmdSmiVoltageMetric.CURRENT,
174-
)
323+
def field_name(self, name):
324+
if self.prefix:
325+
return f"{self.prefix}_{name}"
326+
else:
327+
return name
328+
329+
def store_sample(self, name, value):
330+
self.sample_values[self.field_name(name)].append(value)
175331

176-
# milli * 1-e3 -> volt
177-
self.during_results["core_voltage"].append(milli_volt * 1e-3)
332+
def sample_metrics(self):
333+
self.sample_timestamps.append(time.perf_counter())
334+
335+
if "core_voltage" in self.observables:
336+
self.store_sample("core_voltage", self.device.core_voltage())
178337

179338
if "core_freq" in self.observables:
180-
obj = amdsmi.amdsmi_get_clk_freq(self.device, amdsmi.AmdSmiClkType.GFX)
181-
freq = obj["frequency"][obj["current"]]
182-
self.during_results["core_freq"].append(freq)
339+
self.store_sample("core_freq", self.device.core_freq())
183340

184341
if "mem_freq" in self.observables:
185-
obj = amdsmi.amdsmi_get_clk_freq(self.device, amdsmi.AmdSmiClkType.MEM)
186-
freq = obj["frequency"][obj["current"]]
187-
self.during_results["mem_freq"].append(freq)
342+
self.store_sample("mem_freq", self.device.mem_freq())
188343

189344
if "temperature" in self.observables:
190-
temp = amdsmi.amdsmi_get_temp_metric(
191-
self.device,
192-
amdsmi.AmdSmiTemperatureType.HOTSPOT,
193-
amdsmi.AmdSmiTemperatureMetric.CURRENT,
194-
)
345+
self.store_sample("temperature", self.device.temperature())
346+
347+
if "mem_temperature" in self.observables:
348+
self.store_sample("mem_temperature", self.device.mem_temperature())
195349

196-
self.during_results["temperature"].append(temp)
350+
if "core_activity" in self.observables:
351+
self.store_sample("core_activity", self.device.core_activity())
352+
353+
if "mem_activity" in self.observables:
354+
self.store_sample("mem_activity", self.device.mem_activity())
197355

198356
def after_finish(self):
199-
self.during()
357+
before = self.energy_after_start
358+
after = self.device.total_energy_usage()
359+
self.sample_metrics()
360+
361+
diff = np.uint64(after["energy_accumulator"]) - np.uint64(before["energy_accumulator"])
362+
elapsed_ns = np.uint64(after["timestamp"]) - np.uint64(before["timestamp"])
363+
resolution = before["counter_resolution"]
364+
energy_uj = float(diff) * float(resolution)
200365

201366
# Energy is an exception as it does not need integration over time
202367
if "energy" in self.observables:
203-
before = self.energy_after_start
204-
after = amdsmi.amdsmi_get_energy_count(self.device)
205-
206-
# This field changed names in rocm 6.4
207-
if "energy_accumulator" in before:
208-
energy_field = "energy_accumulator"
209-
elif "power" in before:
210-
energy_field = "power"
211-
else:
212-
raise RuntimeError(f"invalid result from amdsmi_get_energy_count: {before}")
368+
# microJ * 1e-6 -> J
369+
self.results_per_iteration[self.field_name("energy")].append(energy_uj * 1e-6)
213370

214-
diff = np.uint64(after[energy_field]) - np.uint64(before[energy_field])
215-
resolution = before["counter_resolution"]
216-
energy_mj = float(diff) * float(resolution)
371+
if "power" in self.observables:
372+
self.results_per_iteration[self.field_name("power")].append(energy_uj / elapsed_ns * 1e3)
217373

218-
# microJ * 1e-6 -> J
219-
self.iteration_results["energy"].append(energy_mj * 1e-6)
374+
# normalize timestamps to [0, 1] such that integral (trapezoid) is the mean
375+
xs = np.array(self.sample_timestamps)
376+
xs = (xs - xs.min()) / (xs.max() - xs.min())
220377

221-
# For the others, we integrate over time and take the average
222-
x = self.during_timestamps
223-
for key, values in self.during_results.items():
224-
# np.trapezoid was np.trapz in older versions of np
225-
avg = np.trapezoid(values, x) / np.ptp(x)
226-
self.iteration_results[key].append(avg)
378+
for key, values in self.sample_values.items():
379+
# Could not sample, skip field
380+
if not values:
381+
continue
227382

228-
def get_results(self):
229-
results = dict()
383+
# If all values are the same, take that value directly.
384+
# This preserves that value bitwise exactly and prevents
385+
# rounding errors that occur in trapezoid
386+
if all(v == values[0] for v in values):
387+
result = values[0]
388+
else:
389+
result = trapezoid(values, x=xs)
230390

231-
for key in list(self.iteration_results):
232-
# Average of results at each iteration
233-
avg = np.average(self.iteration_results[key])
391+
self.results_per_iteration[key].append(result)
234392

235-
# Reset to empty
236-
self.iteration_results[key] = []
393+
def get_results(self):
394+
results = dict()
237395

238-
if self.prefix:
239-
results[f"{self.prefix}_{key}"] = avg
240-
else:
241-
results[key] = avg
396+
for key in list(self.results_per_iteration):
397+
# Take average and reset!
398+
results[key] = np.average(self.results_per_iteration[key])
399+
self.results_per_iteration[key] = []
242400

243401
return results

0 commit comments

Comments
 (0)