Skip to content

Commit 55254f3

Browse files
Merge pull request #18 from stackav-oss/feature/jmanning/wheel-fixing
Platform-agnostic wheel fixes
2 parents c352cad + a6b6383 commit 55254f3

25 files changed

Lines changed: 96 additions & 117 deletions

benchmarks/bnb_dequantize_blockwise_benchmark.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,7 @@ def main( # noqa: PLR0913
187187
error_msg = "bitsandbytes must be installed and enabled via CONCH_ENABLE_BNB=1"
188188
raise NotImplementedError(error_msg)
189189

190-
from bitsandbytes.functional import ( # type: ignore[import-not-found, import-untyped, unused-ignore] # isort:skip
191-
dequantize_4bit as bnb_dequantize_4bit,
192-
)
190+
from bitsandbytes.functional import dequantize_4bit as bnb_dequantize_4bit
193191
from bitsandbytes.functional import quantize_4bit as bnb_quantize_4bit
194192

195193
bnb_quantized, bnb_state = bnb_quantize_4bit(

benchmarks/bnb_quantize_blockwise_benchmark.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,9 +177,7 @@ def main( # noqa: PLR0913
177177
error_msg = "bitsandbytes must be installed and enabled via CONCH_ENABLE_BNB=1"
178178
raise NotImplementedError(error_msg)
179179

180-
from bitsandbytes.functional import ( # type: ignore[import-not-found, import-untyped, unused-ignore] # isort:skip
181-
quantize_4bit as bnb_quantize_4bit,
182-
)
180+
from bitsandbytes.functional import quantize_4bit as bnb_quantize_4bit
183181

184182
bnb_output, bnb_state = bnb_quantize_4bit(
185183
x,

benchmarks/mixed_precision_gemm_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
1919

2020
if envs.CONCH_ENABLE_VLLM and current_platform.has_cuda():
21-
from vllm import _custom_ops as vllm_custom_ops # type: ignore[import-not-found, unused-ignore]
21+
from vllm import _custom_ops as vllm_custom_ops
2222
else:
2323
vllm_custom_ops = None # type: ignore[assignment, unused-ignore]
2424

benchmarks/paged_attention_benchmark.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
1616

1717
if envs.CONCH_ENABLE_VLLM and current_platform.has_cuda():
18-
from vllm._custom_ops import (
19-
paged_attention_v2 as vllm_paged_attention_v2, # type: ignore[import-not-found, import-untyped, unused-ignore]
20-
)
18+
from vllm._custom_ops import paged_attention_v2 as vllm_paged_attention_v2
2119
else:
2220
vllm_paged_attention_v2 = None # type: ignore[assignment, unused-ignore]
2321

benchmarks/paged_attention_vs_flash_benchmark.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
1616

1717
if envs.CONCH_ENABLE_VLLM and current_platform.is_nvidia():
18-
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined, import-not-found, import-untyped, unused-ignore] # isort:skip
19-
flash_attn_with_kvcache,
20-
)
18+
from vllm.vllm_flash_attn import flash_attn_with_kvcache # type: ignore[attr-defined, unused-ignore]
2119
else:
2220
flash_attn_with_kvcache = None # type: ignore[assignment, unused-ignore]
2321

benchmarks/varlen_attention_benchmark.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
from conch.utils.benchmark import BenchmarkMetadata, benchmark_it
1717

1818
if envs.CONCH_ENABLE_VLLM and current_platform.is_nvidia():
19-
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined, import-not-found, unused-ignore] # isort:skip
20-
flash_attn_varlen_func,
21-
)
19+
from vllm.vllm_flash_attn import flash_attn_varlen_func # type: ignore[attr-defined, unused-ignore]
2220
else:
2321
flash_attn_varlen_func = None # type: ignore[assignment, unused-ignore]
2422

conch/reference/activation/gelu_tanh_and_mul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _gelu_tanh_and_mul_pytorch_ref(x: torch.Tensor) -> torch.Tensor:
1717

1818
def _gelu_tanh_and_mul_vllm_ref(x: torch.Tensor) -> torch.Tensor:
1919
"""vLLM reference gelu_tanh_and_mul impl."""
20-
from vllm.model_executor.layers.activation import GeluAndMul # type: ignore[import-not-found, unused-ignore]
20+
from vllm.model_executor.layers.activation import GeluAndMul
2121

2222
gelu_layer = GeluAndMul("tanh")
2323
return gelu_layer.forward_cuda(x) # type: ignore[no-any-return, unused-ignore]

conch/reference/activation/silu_and_mul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _silu_and_mul_pytorch_ref(x: torch.Tensor) -> torch.Tensor:
1717

1818
def _silu_and_mul_vllm_ref(x: torch.Tensor) -> torch.Tensor:
1919
"""vLLM reference silu and mul implementation."""
20-
from vllm.model_executor.layers.activation import SiluAndMul # type: ignore[import-not-found, unused-ignore]
20+
from vllm.model_executor.layers.activation import SiluAndMul
2121

2222
silu_layer = SiluAndMul() # type: ignore[no-untyped-call, unused-ignore]
2323
return silu_layer.forward_cuda(x) # type: ignore[no-any-return, unused-ignore]

conch/reference/embedding/rotary_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,18 @@ def _rotary_embedding_vllm_ref(
9999
offsets: torch.Tensor | None = None,
100100
) -> tuple[torch.Tensor, torch.Tensor]:
101101
"""vLLM reference rotary_embedding impl."""
102-
from vllm import _custom_ops as ops # type: ignore[import-not-found, unused-ignore]
102+
from vllm import _custom_ops as vllm_custom_ops
103103

104104
cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype)
105105

106-
# ops.rotary_embedding()/batched_rotary_embedding()
106+
# vllm_custom_ops.rotary_embedding()/batched_rotary_embedding()
107107
# are in-place operations that update the query and key tensors.
108108
if offsets is not None:
109-
ops.batched_rotary_embedding(
109+
vllm_custom_ops.batched_rotary_embedding(
110110
positions, query, key, head_size, cos_sin_cache, is_neox_style, rotary_dim, offsets
111111
)
112112
else:
113-
ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox_style)
113+
vllm_custom_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox_style)
114114

115115
return query, key
116116

conch/reference/normalization/gemma_rms_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _gemma_rms_norm_vllm_ref(
3737
residual: torch.Tensor | None,
3838
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
3939
"""vLLM reference gemma_rms_norm impl."""
40-
from vllm.model_executor.layers.layernorm import GemmaRMSNorm # type: ignore[import-not-found, unused-ignore]
40+
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
4141

4242
layer = GemmaRMSNorm(hidden_size=weight.size(0), eps=variance_epsilon)
4343
layer.weight = torch.nn.Parameter(weight)

0 commit comments

Comments
 (0)