Skip to content

Commit ed940ae

Browse files
committed
Remove merge conflict lines
1 parent 641c73b commit ed940ae

1 file changed

Lines changed: 0 additions & 161 deletions

File tree

Ironwood/src/benchmark_host_device.py

Lines changed: 0 additions & 161 deletions
Original file line number | Diff line number | Diff line change
@@ -38,167 +38,6 @@ def benchmark_host_device(
3838
num_devices_to_perform_h2d = 1
3939
target_devices = jax.devices()[:num_devices_to_perform_h2d]
4040

41-
<<<<<<< Updated upstream
42-
print(
43-
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}",
44-
flush=True
45-
)
46-
47-
# Performance Lists
48-
h2d_perf, d2h_perf = [], []
49-
50-
# Profiling Context
51-
import contextlib
52-
if trace_dir:
53-
profiler_context = jax.profiler.trace(trace_dir)
54-
else:
55-
profiler_context = contextlib.nullcontext()
56-
57-
with profiler_context:
58-
# Warmup
59-
for _ in range(2):
60-
device_array = jax.device_put(host_data)
61-
device_array.block_until_ready()
62-
host_out = np.array(device_array)
63-
device_array.delete()
64-
del host_out
65-
66-
for i in range(num_runs):
67-
# Step Context
68-
if trace_dir:
69-
step_context = jax.profiler.StepTraceAnnotation("host_device", step_num=i)
70-
else:
71-
step_context = contextlib.nullcontext()
72-
73-
with step_context:
74-
# H2D
75-
if h2d_type == "simple":
76-
t0 = time.perf_counter()
77-
# Simple device_put
78-
device_array = jax.device_put(host_data)
79-
device_array.block_until_ready()
80-
t1 = time.perf_counter()
81-
82-
# Verify H2D shape
83-
assert device_array.shape == host_data.shape
84-
85-
h2d_perf.append((t1 - t0) * 1000)
86-
87-
# D2H
88-
t2 = time.perf_counter()
89-
90-
# Simple device_get
91-
# Note: device_get returns a numpy array (copy)
92-
_ = jax.device_get(device_array)
93-
94-
t3 = time.perf_counter()
95-
d2h_perf.append((t3 - t2) * 1000)
96-
97-
device_array.delete()
98-
elif h2d_type == "pipelined":
99-
target_chunk_size_mib = 16 # Sweet spot from profiling
100-
num_devices = len(target_devices)
101-
102-
tensors_on_device = []
103-
104-
# Calculate chunks per device
105-
data_per_dev = data_size_mib / num_devices
106-
chunks_per_dev = int(data_per_dev / target_chunk_size_mib)
107-
chunks_per_dev = max(1, chunks_per_dev)
108-
109-
chunks = np.array_split(host_data, chunks_per_dev * num_devices, axis=0)
110-
111-
t0 = time.perf_counter()
112-
if chunks_per_dev > 1:
113-
# We need to map chunks to the correct device
114-
# This simple example assumes chunks are perfectly divisible and ordered
115-
# In production, use `jax.sharding` mesh logic for complex layouts
116-
117-
# approach 1: simple for loop
118-
for idx, chunk in enumerate(chunks):
119-
if num_devices > 1:
120-
dev = target_devices[idx % num_devices]
121-
else:
122-
dev = target_devices[0]
123-
tensors_on_device.append(jax.device_put(chunk, dev))
124-
# Re-assemble array
125-
result = jnp.vstack(tensors_on_device)
126-
# Wait for all chunks to be transferred
127-
result.block_until_ready()
128-
129-
# approach 2: generator (slightly less overhead)
130-
# def chunk_generator(num_devices, chunks_per_dev):
131-
# for n in range(chunks_per_dev):
132-
# for d in range(num_devices):
133-
# # 1. Get the specific small chunk
134-
# chunk = chunks[d*chunks_per_dev+n]
135-
136-
# # 2. Trigger an individual DMA transfer for this specific chunk
137-
# # This is where NUMA-local memory access matters
138-
# yield jax.device_put(chunk, target_devices[d])
139-
140-
# # Re-assemble array
141-
# result = jnp.vstack(list(chunk_generator(num_devices, chunks_per_dev)))
142-
# # Wait for all chunks to be transferred
143-
# result.block_until_ready()
144-
else:
145-
print(f"Warning: {data_size_mib=} is not larger than {target_chunk_size_mib=}, falling back to standard JAX put.")
146-
# Fallback to standard JAX put for small data
147-
result = jax.device_put(host_data, target_devices[0])
148-
result.block_until_ready()
149-
150-
t1 = time.perf_counter()
151-
h2d_perf.append((t1 - t0) * 1000)
152-
153-
# D2H
154-
t2 = time.perf_counter()
155-
# Simple device_get
156-
# Note: device_get returns a numpy array (copy)
157-
_ = jax.device_get(result)
158-
159-
t3 = time.perf_counter()
160-
if not np.allclose(result, host_data):
161-
print("pipelined result not equal to host_data")
162-
d2h_perf.append((t3 - t2) * 1000)
163-
164-
for r in tensors_on_device:
165-
r.delete()
166-
del tensors_on_device
167-
168-
return {
169-
"H2D_Bandwidth_ms": h2d_perf,
170-
"D2H_Bandwidth_ms": d2h_perf,
171-
}
172-
173-
def benchmark_host_device_calculate_metrics(
174-
data_size_mib: int,
175-
H2D_Bandwidth_ms: List[float],
176-
D2H_Bandwidth_ms: List[float],
177-
h2d_type: str = "simple",
178-
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
179-
"""Calculates metrics for Host-Device transfer."""
180-
params = locals().items()
181-
182-
# Filter out list params from metadata to avoid explosion
183-
metadata_keys = {
184-
"data_size_mib",
185-
}
186-
metadata = {k: v for k, v in params if k in metadata_keys}
187-
metadata["dtype"] = "float32"
188-
metadata["h2d_type"] = h2d_type
189-
190-
metrics = {}
191-
192-
def add_metric(name, ms_list):
193-
# Report Bandwidth (GiB/s)
194-
# Handle division by zero if ms is 0
195-
bw_list = [
196-
((data_size_mib / 1024) / (ms / 1000)) if ms > 0 else 0.0
197-
for ms in ms_list
198-
]
199-
stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")
200-
=======
201-
>>>>>>> Stashed changes
20241
print(
20342
f"Benchmarking Transfer with Data Size: {data_size_mib} MB for {num_runs} iterations with {h2d_type=}",
20443
flush=True

0 commit comments

Comments
 (0)