diff --git a/tests/test_calibration_data_device.py b/tests/test_calibration_data_device.py index 6a75822f6..d8dfcd378 100644 --- a/tests/test_calibration_data_device.py +++ b/tests/test_calibration_data_device.py @@ -16,13 +16,11 @@ """ import os -import types -import unittest - import pytest import torch import torch.nn as nn - +import types +import unittest from gptqmodel.models.base import BaseQModel from gptqmodel.quantization import QuantizeConfig @@ -210,6 +208,7 @@ class FakeGPTQModel: move_input_capture_example = BaseQModel.move_input_capture_example prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs run_input_capture = BaseQModel.run_input_capture + _sanitize_input_ids_for_embeddings = BaseQModel._sanitize_input_ids_for_embeddings def __init__(self): self.quantize_config = types.SimpleNamespace( @@ -355,6 +354,7 @@ class FakeGPTQModel: move_input_capture_example = BaseQModel.move_input_capture_example prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs run_input_capture = BaseQModel.run_input_capture + _sanitize_input_ids_for_embeddings = BaseQModel._sanitize_input_ids_for_embeddings def __init__(self): self.quantize_config = types.SimpleNamespace( @@ -509,6 +509,7 @@ class FakeGPTQModel: move_input_capture_example = BaseQModel.move_input_capture_example prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs run_input_capture = BaseQModel.run_input_capture + _sanitize_input_ids_for_embeddings = BaseQModel._sanitize_input_ids_for_embeddings def __init__(self): self.quantize_config = types.SimpleNamespace(