diff --git a/ds/merge_model.py b/ds/merge_model.py index 1a0d363..e6a204c 100644 --- a/ds/merge_model.py +++ b/ds/merge_model.py @@ -260,8 +260,8 @@ def validate_vit_consistency(model, vit_path, img_path): sample_image = Image.open(BytesIO(response.content)).convert("RGB") sample_image = sample_image.resize((560, 560)) - rice_model = MLCDVisionModel.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}) - processor = CLIPImageProcessor.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, use_fast=True) + rice_model = MLCDVisionModel.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, torch_dtype=torch.float32) + processor = CLIPImageProcessor.from_pretrained(vit_path, use_fast=True) rice_inputs = processor.preprocess(images=sample_image, return_tensors="pt").to(dtype=model.dtype, device=rice_model.device) rice_model = rice_model.eval()