Skip to content

Commit ccdc35f

Browse files
Provide bugfix for owlv2 in old inference, regarding monkey-patching with torch.compile(...) (#2270)
* Provide bugfix for owlv2 in old inference, regarding monkey-patching with torch.compile
* Make linters happy
1 parent 6584f8e commit ccdc35f

4 files changed

Lines changed: 16 additions & 3 deletions

File tree

inference/core/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "1.2.4"
1+
__version__ = "1.2.5"
22

33

44
if __name__ == "__main__":

inference/models/owlv2/owlv2.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@
5353
extract_image_payload_and_type,
5454
load_image_rgb,
5555
)
56+
from inference_models.models.owlv2.owlv2_hf import (
57+
monkey_patch_vision_encoder_before_compilation,
58+
)
5659

5760
CPU_IMAGE_EMBED_CACHE_SIZE = OWLV2_CPU_IMAGE_CACHE_SIZE
5861
PRELOADED_HF_MODELS = {}
@@ -120,6 +123,9 @@ def __new__(cls, huggingface_id: str):
120123

121124
if OWLV2_COMPILE_MODEL:
122125
torch._dynamo.config.suppress_errors = True
126+
model._model = monkey_patch_vision_encoder_before_compilation(
127+
model._model
128+
)
123129
model.owlv2.vision_model = torch.compile(model.owlv2.vision_model)
124130
instance.model = model
125131
cls._instances[huggingface_id] = instance

inference/models/owlv2/owlv2_inference_models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
ALLOW_INFERENCE_MODELS_DIRECTLY_ACCESS_LOCAL_PACKAGES,
1717
ALLOW_INFERENCE_MODELS_UNTRUSTED_PACKAGES,
1818
API_KEY,
19-
DEVICE,
2019
DISABLED_INFERENCE_MODELS_BACKENDS,
2120
MAX_DETECTIONS,
2221
OWLV2_COMPILE_MODEL,
@@ -43,7 +42,10 @@
4342
ReferenceBoundingBox,
4443
ReferenceExample,
4544
)
46-
from inference_models.models.owlv2.owlv2_hf import OWLv2HF
45+
from inference_models.models.owlv2.owlv2_hf import (
46+
OWLv2HF,
47+
monkey_patch_vision_encoder_before_compilation,
48+
)
4749

4850
PRELOADED_HF_MODELS = {}
4951

@@ -96,6 +98,9 @@ def __new__(
9698
if OWLV2_COMPILE_MODEL:
9799
logger.info("Compiling OWLv2 model %s", huggingface_id)
98100
torch._dynamo.config.suppress_errors = True
101+
model._model = monkey_patch_vision_encoder_before_compilation(
102+
model._model
103+
)
99104
model._model.owlv2.vision_model = torch.compile(
100105
model._model.owlv2.vision_model
101106
)

inference_models/inference_models/models/owlv2/owlv2_hf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
def monkey_patch_vision_encoder_before_compilation(
6060
model: Owlv2ForObjectDetection,
6161
) -> Owlv2ForObjectDetection:
62+
# IMPORTANT: This function is imported by the `inference` package (inference/models/owlv2/) - do not move or rename it. This import must work.
63+
# It's brittle, but we had no other choice :)
6264
"""
6365
Due to global changes in transformers: https://github.com/huggingface/transformers/pull/43590
6466
our way of compiling owlv2 vision_model turned out to be invalid.

0 commit comments

Comments
 (0)