Skip to content

Commit 9e0d798

Browse files
Merge pull request #30 from stackav-oss/feature/jmanning/cleanup-mp-gemm
Tune MP GEMM kernel
2 parents: 2c95868 + 389c0c8 — commit 9e0d798

2 files changed

Lines changed: 75 additions & 85 deletions

File tree

conch/kernels/quantization/gemm.py

Lines changed: 5 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -135,33 +135,13 @@ def _get_metadata_eviction_policy() -> str:
135135

136136
def _get_tuning_parameters() -> dict[str, int]:
137137
"""Get block sizes/tuning parameters for current device."""
138-
device_name = current_platform.get_device_name()
139-
140-
if "H100" in device_name:
141-
return {
142-
"cxpr_block_size_m": 128,
143-
"cxpr_block_size_n": 128,
144-
"cxpr_block_size_k": 128,
145-
"cxpr_group_size_m": 8,
146-
"num_warps": 8,
147-
"num_stages": 2,
148-
}
149-
150-
if "MI300X" in device_name:
151-
return {
152-
"cxpr_block_size_m": 128,
153-
"cxpr_block_size_n": 64,
154-
"cxpr_block_size_k": 128,
155-
"cxpr_group_size_m": 16,
156-
"num_warps": 8,
157-
"num_stages": 2,
158-
}
159-
160138
return {
161-
"cxpr_block_size_m": 64,
139+
"cxpr_block_size_m": 128,
162140
"cxpr_block_size_n": 64,
163-
"cxpr_block_size_k": 32,
164-
"cxpr_group_size_m": 8,
141+
"cxpr_block_size_k": 64,
142+
"cxpr_group_size_m": 16,
143+
"num_warps": 8,
144+
"num_stages": 2,
165145
}
166146

167147

conch/ops/quantization/gemm.py

Lines changed: 70 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -51,31 +51,33 @@ def create_mixed_precision_metadata(
5151
acc_dtype: torch.dtype | None = None,
5252
meta_dtype: torch.dtype | None = None,
5353
scaled_activations: bool = False,
54+
strict: bool = False,
5455
) -> MixedPrecisionMatmulMetadata:
5556
"""Verify sizes and dtypes of tensors and deduce metadata parameters."""
56-
expected_input_matrix_rank: Final = 2
57+
if strict:
58+
expected_input_matrix_rank: Final = 2
5759

58-
if (x_rank := len(x.shape)) != expected_input_matrix_rank:
59-
error_msg = f"Unexpected number of dimensions of input tensor x: {x_rank}"
60-
raise ValueError(error_msg)
60+
if (x_rank := len(x.shape)) != expected_input_matrix_rank:
61+
error_msg = f"Unexpected number of dimensions of input tensor x: {x_rank}"
62+
raise ValueError(error_msg)
6163

62-
if (w_q_packed_rank := len(w_q_packed.shape)) != expected_input_matrix_rank:
63-
error_msg = f"Unexpected number of dimensions of input tensor w_q_packed: {w_q_packed_rank}"
64-
raise ValueError(error_msg)
64+
if (w_q_packed_rank := len(w_q_packed.shape)) != expected_input_matrix_rank:
65+
error_msg = f"Unexpected number of dimensions of input tensor w_q_packed: {w_q_packed_rank}"
66+
raise ValueError(error_msg)
6567

66-
if (w_s_rank := len(w_s.shape)) != expected_input_matrix_rank:
67-
error_msg = f"Unexpected number of dimensions of input tensor w_s: {w_s_rank}"
68-
raise ValueError(error_msg)
68+
if (w_s_rank := len(w_s.shape)) != expected_input_matrix_rank:
69+
error_msg = f"Unexpected number of dimensions of input tensor w_s: {w_s_rank}"
70+
raise ValueError(error_msg)
6971

70-
if w_zp is not None and (w_zp_rank := len(w_zp.shape)) != expected_input_matrix_rank:
71-
error_msg = f"Unexpected number of dimensions of input tensor w_zp: {w_zp_rank}"
72-
raise ValueError(error_msg)
72+
if w_zp is not None and (w_zp_rank := len(w_zp.shape)) != expected_input_matrix_rank:
73+
error_msg = f"Unexpected number of dimensions of input tensor w_zp: {w_zp_rank}"
74+
raise ValueError(error_msg)
7375

74-
# Expecting some form of 32-bit packing
75-
expected_packed_dtypes: Final = [torch.uint32, torch.int32]
76-
if (packed_dtype := w_q_packed.dtype) not in expected_packed_dtypes:
77-
error_msg = f"Invalid datatype for packed weights: {packed_dtype}"
78-
raise ValueError(error_msg)
76+
# Expecting some form of 32-bit packing
77+
expected_packed_dtypes: Final = [torch.uint32, torch.int32]
78+
if (packed_dtype := w_q_packed.dtype) not in expected_packed_dtypes:
79+
error_msg = f"Invalid datatype for packed weights: {packed_dtype}"
80+
raise ValueError(error_msg)
7981

8082
# Assume 32-bit packing
8183
packed_bitwidth: Final = 32
@@ -86,25 +88,27 @@ def create_mixed_precision_metadata(
8688

8789
unpack_mask = 2**weight_size_bits - 1
8890

89-
# Verify shape of w_s
90-
expected_scales_shape: Final = (k_dim // group_size, n_dim)
91-
if (scales_shape := w_s.shape) != expected_scales_shape:
92-
error_msg = f"Invalid w_s shape (expected: {expected_scales_shape}, actual: {scales_shape})"
93-
raise ValueError(error_msg)
94-
9591
# Check if zeros is a scalar value
9692
zero_is_scalar = False if w_zp is None else w_zp.numel() == 1
97-
# Expected shape of zeros tensor if 1) it is not scalar 2) it is not None
98-
expected_zeros_shape: Final = (k_dim // group_size, n_dim)
99-
# Verify shape of w_zp
100-
if not zero_is_scalar and w_zp is not None and (zeros_shape := w_zp.shape) != expected_zeros_shape:
101-
error_msg = f"Invalid w_zp shape (expected: {expected_zeros_shape}, actual: {zeros_shape})"
102-
raise ValueError(error_msg)
103-
104-
# Not supporting scaled activations right now, but we can add support later if needed. This simplifies the interface
105-
if scaled_activations:
106-
error_msg = "Scaled activations not yet implemented (need to deduce correct channel_scale_mode)"
107-
raise NotImplementedError(error_msg)
93+
94+
if strict:
95+
# Verify shape of w_s
96+
expected_scales_shape: Final = (k_dim // group_size, n_dim)
97+
if (scales_shape := w_s.shape) != expected_scales_shape:
98+
error_msg = f"Invalid w_s shape (expected: {expected_scales_shape}, actual: {scales_shape})"
99+
raise ValueError(error_msg)
100+
101+
# Expected shape of zeros tensor if 1) it is not scalar 2) it is not None
102+
expected_zeros_shape: Final = (k_dim // group_size, n_dim)
103+
# Verify shape of w_zp
104+
if not zero_is_scalar and w_zp is not None and (zeros_shape := w_zp.shape) != expected_zeros_shape:
105+
error_msg = f"Invalid w_zp shape (expected: {expected_zeros_shape}, actual: {zeros_shape})"
106+
raise ValueError(error_msg)
107+
108+
# Not supporting scaled activations right now, but we can add support later if needed. This simplifies the interface
109+
if scaled_activations:
110+
error_msg = "Scaled activations not yet implemented (need to deduce correct channel_scale_mode)"
111+
raise NotImplementedError(error_msg)
108112

109113
return MixedPrecisionMatmulMetadata(
110114
m_dim=m_dim,
@@ -139,6 +143,7 @@ def mixed_precision_gemm(
139143
acc_dtype: torch.dtype | None = None,
140144
meta_dtype: torch.dtype | None = None,
141145
scaled_activations: bool = False,
146+
strict: bool = False,
142147
) -> torch.Tensor:
143148
"""Mixed precision GEMM operation."""
144149
metadata = create_mixed_precision_metadata(
@@ -153,6 +158,7 @@ def mixed_precision_gemm(
153158
acc_dtype=acc_dtype,
154159
meta_dtype=meta_dtype,
155160
scaled_activations=scaled_activations,
161+
strict=strict,
156162
)
157163

158164
output = torch.zeros((metadata.m_dim, metadata.n_dim), device=x.device, dtype=metadata.output_dtype)
@@ -168,42 +174,45 @@ def create_scaled_metadata(
168174
scale_a: torch.Tensor,
169175
scale_b: torch.Tensor,
170176
output_dtype: torch.dtype,
177+
strict: bool = False,
171178
) -> ScaledMatmulMetadata:
172179
"""Verify sizes and dtypes of tensors and deduce metadata parameters."""
173-
expected_input_matrix_rank: Final = 2
180+
if strict:
181+
expected_input_matrix_rank: Final = 2
174182

175-
if (a_rank := len(a.shape)) != expected_input_matrix_rank:
176-
error_msg = f"Unexpected number of dimensions of input tensor a: {a_rank}"
177-
raise ValueError(error_msg)
183+
if (a_rank := len(a.shape)) != expected_input_matrix_rank:
184+
error_msg = f"Unexpected number of dimensions of input tensor a: {a_rank}"
185+
raise ValueError(error_msg)
178186

179-
if (b_rank := len(b.shape)) != expected_input_matrix_rank:
180-
error_msg = f"Unexpected number of dimensions of input tensor b: {b_rank}"
181-
raise ValueError(error_msg)
187+
if (b_rank := len(b.shape)) != expected_input_matrix_rank:
188+
error_msg = f"Unexpected number of dimensions of input tensor b: {b_rank}"
189+
raise ValueError(error_msg)
182190

183-
if a.dtype != b.dtype:
184-
error_msg = f"Input tensors a and b must have the same datatype (a: {a.dtype}, b: {b.dtype})"
185-
raise ValueError(error_msg)
191+
if a.dtype != b.dtype:
192+
error_msg = f"Input tensors a and b must have the same datatype (a: {a.dtype}, b: {b.dtype})"
193+
raise ValueError(error_msg)
186194

187195
m_dim, k_dim = a.shape
188196
_, n_dim = b.shape
189197

190-
if scale_a.numel() != 1:
191-
if (scale_a_rank := len(scale_a.shape)) != expected_input_matrix_rank:
192-
error_msg = f"Unexpected number of dimensions of input tensor scale_a: {scale_a_rank}"
193-
raise ValueError(error_msg)
198+
if strict:
199+
if scale_a.numel() != 1:
200+
if (scale_a_rank := len(scale_a.shape)) != expected_input_matrix_rank:
201+
error_msg = f"Unexpected number of dimensions of input tensor scale_a: {scale_a_rank}"
202+
raise ValueError(error_msg)
194203

195-
if scale_a.shape[0] != m_dim:
196-
error_msg = f"Invalid scale_a shape (expected: ({m_dim},), actual: {scale_a.shape})"
197-
raise ValueError(error_msg)
204+
if scale_a.shape[0] != m_dim:
205+
error_msg = f"Invalid scale_a shape (expected: ({m_dim},), actual: {scale_a.shape})"
206+
raise ValueError(error_msg)
198207

199-
if scale_b.numel() != 1:
200-
if (scale_b_rank := len(scale_b.shape)) != expected_input_matrix_rank:
201-
error_msg = f"Unexpected number of dimensions of input tensor scale_b: {scale_b_rank}"
202-
raise ValueError(error_msg)
208+
if scale_b.numel() != 1:
209+
if (scale_b_rank := len(scale_b.shape)) != expected_input_matrix_rank:
210+
error_msg = f"Unexpected number of dimensions of input tensor scale_b: {scale_b_rank}"
211+
raise ValueError(error_msg)
203212

204-
if scale_b.shape[0] != n_dim:
205-
error_msg = f"Invalid scale_b shape (expected: ({n_dim},), actual: {scale_b.shape})"
206-
raise ValueError(error_msg)
213+
if scale_b.shape[0] != n_dim:
214+
error_msg = f"Invalid scale_b shape (expected: ({n_dim},), actual: {scale_b.shape})"
215+
raise ValueError(error_msg)
207216

208217
return ScaledMatmulMetadata(
209218
m_dim=m_dim,
@@ -228,9 +237,10 @@ def scaled_gemm(
228237
scale_b: torch.Tensor,
229238
output_dtype: torch.dtype,
230239
bias: torch.Tensor | None = None,
240+
strict: bool = False,
231241
) -> torch.Tensor:
232242
"""Scaled GEMM operation."""
233-
metadata = create_scaled_metadata(a, b, scale_a, scale_b, output_dtype)
243+
metadata = create_scaled_metadata(a, b, scale_a, scale_b, output_dtype, strict=strict)
234244

235245
output = torch.zeros((metadata.m_dim, metadata.n_dim), device=a.device, dtype=output_dtype)
236246

0 commit comments

Comments (0)