
Commit 3e3433d

Fixing the remaining pylint errors or adding exceptions where they can't be fixed
Parent: b1304eb

19 files changed: 281 additions & 265 deletions
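Most of this commit leans on pylint's inline pragmas, in two placements. As a generic sketch (not code from this commit): a "# pylint: disable=..." comment on a line of its own suppresses the named check from that point through the end of the enclosing scope, while the same pragma appended to the end of a line suppresses it for that line only.

def benchmark(num_runs: int = 1, trace_dir: str = None):
    # pylint: disable=unused-argument
    # Scope-level pragma: silences unused-argument for the rest of this
    # function, useful when the signature must stay uniform across benchmarks.
    return 0


FLAG = "--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true"  # pylint: disable=line-too-long
# Trailing pragma: applies only to the one line it sits on.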

Ironwood/src/benchmark_collectives.py

Lines changed: 7 additions & 2 deletions
@@ -57,7 +57,8 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:
 
 
 def get_sharding_axis(dim_str: str, mesh: Mesh) -> tuple[str, ...]:
-    """Computes sharding axis names from dimension string like '1x4' and mesh."""
+    """Computes sharding axis names from dimension string and mesh."""
+    # Example of a dimension string is '1x4'
     dim_tuple = dim_str.split("x")
     dim_tuple = tuple(int(dim) for dim in dim_tuple)
     sharding_axis = tuple(
@@ -203,6 +204,7 @@ def psum_benchmark(
     num_runs: int = 1,
     trace_dir: str = None,
 ) -> Dict[str, Any]:
+    # pylint: disable=unused-argument
     """Benchmarks the psum collective operation.
 
     Args:
@@ -354,6 +356,7 @@ def psum_scatter_benchmark(
     num_runs: int = 1,
     trace_dir: str = None,
 ) -> Dict[str, Any]:
+    # pylint: disable=unused-argument
     """Benchmarks the psum_scatter collective operation.
 
     Args:
@@ -376,7 +379,7 @@ def psum_scatter_benchmark(
         "--xla_sc_disable_megacore_partitioning=true",
         "--xla_tpu_disable_sparse_core_collective_offload_remover=true",
         "--xla_tpu_enable_reduce_scatter_offload_tracing=true",
-        "--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true",
+        "--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true",  # pylint: disable=line-too-long
        "--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true",
         "--xla_tpu_enable_sparse_core_reduce_scatter_v2=true",
         "--xla_tpu_use_tc_device_shape_on_sc=true",
@@ -470,6 +473,7 @@ def all_gather_benchmark(
     num_runs: int = 1,
     trace_dir: str = None,
 ) -> Dict[str, Any]:
+    # pylint: disable=unused-argument
     """Benchmarks the all_gather collective operation.
 
     Args:
@@ -586,6 +590,7 @@ def all_to_all_benchmark(
     num_runs: int = 1,
     trace_dir: str = None,
 ) -> Dict[str, Any]:
+    # pylint: disable=unused-argument
     """Benchmarks the all_to_all collective operation.
 
     Args:
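For orientation, the functions touched above time JAX collectives (psum, psum_scatter, all_gather, all_to_all) launched under shard_map. A minimal sketch of that pattern, not the repo's code; the axis name "x" and the array size are illustrative assumptions:

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def summed(block):
    # Each device holds one shard of the input; psum reduces across "x".
    return jax.lax.psum(block, axis_name="x")

f = jax.jit(shard_map(summed, mesh=mesh, in_specs=P("x"), out_specs=P()))
out = f(jnp.arange(8.0))  # assumes the device count divides 8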

Ironwood/src/benchmark_compute.py

Lines changed: 19 additions & 12 deletions
@@ -318,11 +318,11 @@ def swiglu_fwd(
 
     def f(x):
         with jax.named_scope(MARKER):
-            A, B = jnp.split(x, 2, axis=-1)
-            A_fp32 = A.astype(jnp.float32)
-            B_fp32 = B.astype(jnp.float32)
-            Y_fp32 = jax.nn.silu(A_fp32) * B_fp32
-            return Y_fp32.astype(jnp.bfloat16)
+            a, b = jnp.split(x, 2, axis=-1)
+            a_fp32 = a.astype(jnp.float32)
+            b_fp32 = b.astype(jnp.float32)
+            y_fp32 = jax.nn.silu(a_fp32) * b_fp32
+            return y_fp32.astype(jnp.bfloat16)
 
     mesh = create_mesh(SHARDING_STRATEGY)
     x_sharding = get_rowwise_named_shading(mesh, SHARDING_STRATEGY)
@@ -379,16 +379,17 @@ def swiglu_bwd(
     num_runs: int = 1,
     trace_dir: str = None,
 ) -> Dict[str, Any]:
+    # pylint: disable=invalid-name
     """
     Inverse of swiglu_fwd
     """
 
     def f_fwd(x):
-        A, B = jnp.split(x, 2, axis=-1)
-        A_fp32 = A.astype(jnp.float32)
-        B_fp32 = B.astype(jnp.float32)
-        Y_fp32 = jax.nn.silu(A_fp32) * B_fp32
-        return Y_fp32.astype(jnp.bfloat16)
+        a, b = jnp.split(x, 2, axis=-1)
+        a_fp32 = a.astype(jnp.float32)
+        b_fp32 = b.astype(jnp.float32)
+        y_fp32 = jax.nn.silu(a_fp32) * b_fp32
+        return y_fp32.astype(jnp.bfloat16)
 
     def f(x: jax.Array, dy: jax.Array) -> jax.Array:
         """
@@ -397,7 +398,10 @@ def f(x: jax.Array, dy: jax.Array) -> jax.Array:
         """
         # Get the VJP "pullback" function
         # We ignore the forward result (_y)
-        _y, pullback_fn = jax.vjp(f_fwd, x)
+        # pylint: disable=unused-variable,invalid-name
+        _y, pullback_fn = jax.vjp(
+            f_fwd, x
+        )
         with jax.named_scope(MARKER):
             # Call the pullback function with the upstream gradient
             # This IS the backward pass.
@@ -555,7 +559,10 @@ def f(x: jax.Array, dy: jax.Array) -> jax.Array:
         """
         # Get the VJP "pullback" function
         # We ignore the forward result (_y)
-        _y, pullback_fn = jax.vjp(f_fwd, x)
+        # pylint: disable=unused-variable,invalid-name
+        _y, pullback_fn = jax.vjp(
+            f_fwd, x
+        )
         with jax.named_scope(MARKER):
             # Call the pullback function with the upstream gradient
             # This IS the backward pass.
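The hunks above reflow jax.vjp calls to make room for a pragma; the underlying pattern is worth spelling out. A minimal sketch, not the repo's code: jax.vjp returns the forward value together with a pullback closure, and calling that pullback on the upstream gradient is the backward pass.

import jax
import jax.numpy as jnp

def f_fwd(x):
    return jnp.tanh(x)

x = jnp.ones((4,))
dy = jnp.ones((4,))  # upstream gradient (cotangent)
y, pullback_fn = jax.vjp(f_fwd, x)  # forward value plus pullback closure
(dx,) = pullback_fn(dy)  # backward pass; one cotangent per primal input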

Ironwood/src/benchmark_gemm.py

Lines changed: 8 additions & 7 deletions
@@ -69,7 +69,8 @@ def gemm_multiple_run(
 ) -> Dict[str, Any]:
     """Benchmarks the OUT<M, N>:BF16 = IN0<M, K> dtype x IN1<N, K>:dtype."""
 
-    """Accumulation is FP32. Current supported dtype: float8_e4m3fn, bfloat16."""
+    # Accumulation is FP32. Current supported dtype: float8_e4m3fn,
+    # bfloat16.
 
     def f(x, y):
         with jax.named_scope(MARKER):
@@ -170,8 +171,7 @@ def gemm_simple(
     trace_dir: str = None,
 ) -> Dict[str, Any]:
     """Benchmarks the OUT<M, N>:BF16 = IN0<M, K>:FP8 x IN1<N, K>:FP8."""
-
-    """Accumulation is FP32."""
+    # Accumulation is FP32.
 
     def f(x, y):
         with jax.named_scope(MARKER):
@@ -266,8 +266,7 @@ def gemm_simple_with_dtype(
     trace_dir: str = None,
 ) -> Dict[str, Any]:
     """Benchmarks the OUT<M, N>:BF16 = IN0<M, K>:FP8 x IN1<N, K>:FP8."""
-
-    """Accumulation is FP32."""
+    # Accumulation is FP32.
 
     # Convert string dtypes to jnp dtypes
     lhs_dtype = str_to_dtype(in_dtype_str)
@@ -368,7 +367,8 @@ def gemm_simple_with_dtype_calculate_metrics(
 def gemm(
     m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None
 ) -> Dict[str, Any]:
-    """OUT<M, N>:BF16 = matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) * outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""
+    """OUT<M, N>:BF16 = matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) *
+    outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""
 
     def f(x, y, scale_m, scale_n):
         with jax.named_scope(MARKER):
@@ -473,7 +473,8 @@ def gemm_accum(
     num_runs: int = 1,
     trace_dir: str = None,
 ) -> Dict[str, Any]:
-    """OUT<M, N>:FP32 += matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) * outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""
+    """OUT<M, N>:FP32 += matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) *
+    outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""
 
     def f(out_buffer, x, y, scale_m, scale_n):
         with jax.named_scope(MARKER):
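The recurring edit in this file addresses pylint's pointless-string-statement check: only the first string literal in a function body is its docstring, and a second bare string literal is an expression statement that does nothing, so the commit rewrites each one as a comment. An illustration with a hypothetical function name:

def scaled_matmul(x, y):
    """Computes the scaled product."""
    """Accumulation is FP32."""  # bare no-op expression; pylint flags it
    ...


def scaled_matmul_fixed(x, y):
    """Computes the scaled product."""
    # Accumulation is FP32.  (a comment: no runtime object, no warning)
    ...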

Ironwood/src/benchmark_gemm_numerics.py

Lines changed: 2 additions & 4 deletions
@@ -273,8 +273,7 @@ def gemm_fp8_b128_fp32(
     m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None
 ) -> Dict[str, Any]:
     """FP8 GEMM as DeepSeek-stype quantization, block size: 1x128."""
-
-    """Use dynamic scaling factors."""
+    # Use dynamic scaling factors.
 
     def f(x, y):
         with jax.named_scope(MARKER):
@@ -387,8 +386,7 @@ def gemm_fp8_b128_fp32_static_scaling(
     m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None
 ) -> Dict[str, Any]:
     """FP8 GEMM as DeepSeek-stype quantization, block size: 1x128."""
-
-    """Use static scaling factors."""
+    # Use static scaling factors.
 
     def f(x, y):
         with jax.named_scope(MARKER):
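For context on the functions above: DeepSeek-style 1x128 block quantization gives each run of 128 consecutive values along the reduction dimension its own FP8 scaling factor, and "dynamic" means the factor is computed from the data itself. A rough sketch under our own assumptions (the names, layout, and use of float8_e4m3fn's max finite value of 448 are ours, not the repo's):

import jax.numpy as jnp

def quantize_1x128(x):
    """x: (m, k) with k divisible by 128; returns FP8 blocks and per-block scales."""
    m, k = x.shape
    blocks = x.reshape(m, k // 128, 128).astype(jnp.float32)
    # Dynamic scaling: map each 1x128 block onto the representable range
    # of float8_e4m3fn (max finite value 448) using its own max magnitude.
    absmax = jnp.maximum(jnp.max(jnp.abs(blocks), axis=-1, keepdims=True), 1e-12)
    scale = absmax / 448.0
    q = (blocks / scale).astype(jnp.float8_e4m3fn)
    return q.reshape(m, k), scale  # static scaling would fix scale up front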

Ironwood/src/benchmark_hbm.py

Lines changed: 2 additions & 2 deletions
@@ -89,8 +89,8 @@ def single_device_hbm_copy_calculate_metrics(
     )
     print(
         f"Tensor size: {tensor_size_bytes / 1024**2} MB, "
-        f"time taken (median): {time_statistics.statistics['p50']:.4f} ms, "
-        f"bandwidth (median): {statistics.statistics['p50']:.3f} GB/s"
+        f"time taken (median): {time_statistics.statistics["p50"]:.4f} ms, "
+        f"bandwidth (median): {statistics.statistics["p50"]:.3f} GB/s"
     )
     print()
     # Gather the metrics to report.
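A note on the quote flip above (the same change appears in benchmark_host_device.py below): reusing the outer quote character inside an f-string expression is only valid on Python 3.12+ (PEP 701), so this diff assumes a 3.12+ interpreter. A standalone illustration:

stats = {"p50": 1.234}
print(f"median: {stats["p50"]:.3f} ms")  # fine on 3.12+, SyntaxError before
print(f"median: {stats['p50']:.3f} ms")  # portable spelling for older interpreters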

Ironwood/src/benchmark_host_device.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,7 @@
-"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline)."""
+"""
+Benchmarks Host-to-Device and Device-to-Host transfer performance
+(Simple Baseline).
+"""
 
 import time
 import os
@@ -123,8 +126,8 @@ def add_metric(name, ms_list):
     ]
     stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")
     print(
-        f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, "
-        f"P95: {stats_bw.statistics['p95']}",
+        f"{name}_bw (GiB/s) median: {stats_bw.statistics["p50"]}, "
+        f"P95: {stats_bw.statistics["p95"]}",
         flush=True,
     )
     metrics.update(stats_bw.serialize_statistics())
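For orientation, a minimal sketch of the simple host-to-device / device-to-host baseline this file benchmarks; the 1 GiB buffer and the timing scheme are illustrative assumptions, not the repo's configuration:

import time
import jax
import numpy as np

host = np.ones((256, 1024, 1024), dtype=np.float32)  # 1 GiB host buffer
start = time.perf_counter()
dev = jax.device_put(host, jax.devices()[0])
dev.block_until_ready()  # wait for the host-to-device copy to complete
h2d = time.perf_counter() - start
print(f"H2D: {host.nbytes / 2**30 / h2d:.2f} GiB/s")

start = time.perf_counter()
_ = np.asarray(dev)  # forces the device-to-host copy
d2h = time.perf_counter() - start
print(f"D2H: {host.nbytes / 2**30 / d2h:.2f} GiB/s")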

Ironwood/src/benchmark_inference_compute.py

Lines changed: 0 additions & 86 deletions
@@ -347,89 +347,3 @@ def sigmoid_calculate_metrics(
         dtype=dtype.dtype.name,
     )
 
-
-# def get_output_named_shading(mesh, strategy: ShardingStrategy):
-#     match strategy:
-#         case ShardingStrategy.NO_SHARDING:
-#             return NamedSharding(mesh, P(None))
-#         case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
-#             return NamedSharding(mesh, P("device"))
-#         case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M:
-#             return NamedSharding(mesh, P("device"))
-#         case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N:
-#             assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
-#         case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N:
-#             assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
-
-# def get_out_sharding(strategy: ShardingStrategy):
-#     match strategy:
-#         case ShardingStrategy.NO_SHARDING:
-#             return P(None)
-#         case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
-#             return P("device")
-#         case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M:
-#             return P("device")
-#         case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N:
-#             assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
-#         case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N:
-#             assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
-
-# def add(m: int, dtype: jnp.dtype, num_runs: int = 1, trace_dir: str = None,
-# ) -> Dict[str, Any]:
-#     """
-#     Z = X + Y
-#     """
-#     def f(x, y):
-#         with jax.named_scope(MARKER):
-#             return x + y
-
-#     mesh = create_mesh(SHARDING_STRATEGY)
-#     x_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY)
-#     y_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY)
-#     out_sharding = get_out_sharding(SHARDING_STRATEGY)
-#     jit_sharded_f = jax.jit(
-#         shard_map(
-#             f,
-#             mesh,
-#             in_specs=(x_sharding.spec, y_sharding.spec),
-#             out_specs=out_sharding,
-#             check_rep=False,
-#         )
-#     )
-#     x_shape = (m)
-#     y_shape = (m)
-#     x_dtype = dtype
-#     y_dtype = dtype
-
-#     key = jax.random.key(SEED)
-
-#     def data_generator():
-#         """Creates new random data on host and puts it on device."""
-#         nonlocal key  # Use and update the outer 'key'
-#         key, k1, k2 = jax.random.split(key, 3)
-
-#         x_host = jax.random.normal(k1, x_shape).astype(x_dtype)
-#         y_host = jax.random.normal(k2, y_shape).astype(y_dtype)
-
-#         x_device = jax.device_put(x_host, x_sharding)
-#         y_device = jax.device_put(y_host, y_sharding)
-
-#         return (x_device, y_device)
-
-#     time_ms_list = iteration_timeit(
-#         jit_sharded_f,
-#         data_generator,
-#         matrix_dim=f"{m}",
-#         tries=num_runs,
-#         task="add",
-#         trace_dir=trace_dir,
-#     )
-#     return {"time_ms_list": time_ms_list}

-# def add_calculate_metrics(
-#     m: int, dtype: jnp.dtype, time_ms_list: list[float]
-# ) -> Dict[str, Any]:
-#     scale = 2 if dtype == jnp.bfloat16 else 1
-#     total_bytes = scale * 3 * m
-#     total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY)
-#     return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name)

Ironwood/src/benchmark_send_recv.py

Lines changed: 10 additions & 6 deletions
@@ -84,7 +84,11 @@ def send_recv_benchmark(
     dtype: jnp.dtype,
     trace_dir: str,
 ):
-    """Runs p2p communication, sending tensor_size_bytes from source to target device."""
+    # pylint: disable=unused-argument
+    """
+    Runs p2p communication, sending tensor_size_bytes from source to target
+    device.
+    """
     device_count = jax.local_device_count()
     devices = mesh_utils.create_device_mesh((device_count,))
     mesh = jax.sharding.Mesh(devices, "x")
@@ -120,14 +124,14 @@ def p2p_send(source_id, target_id):
         target_recv_sizes,
         no_recvs,
     )
-    input = jax.random.normal(
+    random_input = jax.random.normal(
         jax.random.key(0), (1, 8, last_dim), dtype=dtype
     )
    output = jnp.zeros((1, 8, last_dim), dtype=dtype)
 
     with jax.named_scope(MARKER):
         ra2a = jax.lax.ragged_all_to_all(
-            operand=input,
+            operand=random_input,
             output=output,
             input_offsets=input_offsets,
             send_sizes=final_send_sizes,
@@ -158,10 +162,10 @@ def p2p_send(source_id, target_id):
 
 
 def send_recv_benchmark_calculate_metrics(
-    source_id: int,
-    target_id: int,
+    source_id: int,  # pylint: disable=unused-argument
+    target_id: int,  # pylint: disable=unused-argument
     num_elements: int,
-    n_repeats: int,
+    n_repeats: int,  # pylint: disable=unused-argument
     dtype: jnp.dtype,
     runtime_ms: float,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
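Two mechanical notes on this file's hunks: renaming input to random_input clears pylint's redefined-builtin warning (a local named input shadows the built-in input() for the rest of its scope), and the trailing per-parameter pragmas are intended to silence unused-argument for individual parameters rather than the whole function. A small illustration with hypothetical names:

import jax.numpy as jnp

def make_operands(last_dim: int):
    random_input = jnp.zeros((1, 8, last_dim))  # was "input", which shadowed builtins.input
    output = jnp.zeros((1, 8, last_dim))
    return random_input, output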
