From d7fc34f114937e02e1cbd1656e79893cd17eb3cc Mon Sep 17 00:00:00 2001 From: inigopm Date: Mon, 2 Mar 2026 11:58:39 +0100 Subject: [PATCH] fix(merge_model): use torch_dtype for MLCDVisionModel and remove invalid device_map from CLIPImageProcessor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CLIPImageProcessor is a processor, not a model — device_map is not a valid argument and raises TypeError at runtime. MLCDVisionModel.from_pretrained should use torch_dtype= (not dtype=) to explicitly set float32. Regression introduced in #52, partially addressed in #88. --- ds/merge_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ds/merge_model.py b/ds/merge_model.py index 1a0d363..e6a204c 100644 --- a/ds/merge_model.py +++ b/ds/merge_model.py @@ -260,8 +260,8 @@ def validate_vit_consistency(model, vit_path, img_path): sample_image = Image.open(BytesIO(response.content)).convert("RGB") sample_image = sample_image.resize((560, 560)) - rice_model = MLCDVisionModel.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}) - processor = CLIPImageProcessor.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, use_fast=True) + rice_model = MLCDVisionModel.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, torch_dtype=torch.float32) + processor = CLIPImageProcessor.from_pretrained(vit_path, use_fast=True) rice_inputs = processor.preprocess(images=sample_image, return_tensors="pt").to(dtype=model.dtype, device=rice_model.device) rice_model = rice_model.eval()