|
13 | 13 | # limitations under the License. |
14 | 14 | """Pytest configuration file with fixtures for add-ons functionality testing""" |
15 | 15 |
|
| 16 | +# Standard |
| 17 | +from pathlib import Path |
| 18 | + |
16 | 19 | # Third Party |
17 | 20 | import pytest |
18 | 21 | import torch |
@@ -67,65 +70,58 @@ def get_gptq_gemm_inputs(request) -> tuple[torch.Tensor, ...]: |
67 | 70 |
|
68 | 71 | i8i8_metadata = [ |
69 | 72 | { |
70 | | - "bs": 4, |
71 | | - "seq_len": 7, |
72 | | - "hid_dim": 256, |
73 | | - "out_feat": 512, |
74 | | - "dtype": torch.float16, |
75 | 73 | "wtype": "per_tensor", # per_channel |
76 | 74 | "atype": "per_tensor_symm", # per_tensor_asymm, per_token |
77 | 75 | "smoothquant": False, |
78 | | - } |
| 76 | + }, |
| 77 | + # { |
| 78 | + # "wtype": "per_channel", # per_channel |
| 79 | + # "atype": "per_tensor_symm", # per_tensor_asymm, per_token |
| 80 | + # "smoothquant": False, |
| 81 | + # }, |
79 | 82 | ] |
80 | 83 |
|
81 | 84 |
|
82 | 85 | @pytest.fixture(scope="session", params=i8i8_metadata) |
83 | 86 | def get_i8i8_gemm_inputs( |
84 | 87 | request, |
85 | | -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, str, bool]: |
| 88 | +) -> tuple[ |
| 89 | + torch.Tensor, |
| 90 | + torch.Tensor, |
| 91 | + torch.Tensor, |
| 92 | + torch.Tensor, |
| 93 | + str, |
| 94 | + str, |
| 95 | + bool, |
| 96 | + torch.Tensor, |
| 97 | +]: |
86 | 98 | """pytest fixture returning test inputs for INT8xINT8 op""" |
87 | 99 |
|
88 | 100 | data = request.param |
89 | | - x = torch.randn( |
90 | | - (data["bs"], data["seq_len"], data["hid_dim"]), |
91 | | - dtype=data["dtype"], |
92 | | - ).clamp(-1, 1) |
93 | | - w_int = torch.randint( |
94 | | - low=-8, |
95 | | - high=8, |
96 | | - size=(data["out_feat"], data["hid_dim"]), |
97 | | - dtype=torch.int8, |
| 101 | + |
| 102 | + filename = ( |
| 103 | + f"ref_w-{data['wtype']}_" |
| 104 | + f"a-{data['atype']}_" |
| 105 | + f"sq-{'Y' if data['smoothquant'] else 'N'}.pt" |
98 | 106 | ) |
99 | | - b = torch.zeros(data["out_feat"], dtype=data["dtype"]) |
100 | | - qdata = create_qdata( |
101 | | - data["wtype"], |
102 | | - data["atype"], |
103 | | - data["hid_dim"], |
104 | | - data["out_feat"], |
105 | | - data["smoothquant"], |
106 | | - data["dtype"], |
| 107 | + addon_references = Path("tests/artifacts/aiu_addons") |
| 108 | + i8i8_data = torch.load(addon_references / filename, weights_only=True) |
| 109 | + |
| 110 | + assert isinstance(i8i8_data, dict) |
| 111 | + assert data["wtype"] == i8i8_data["weight_quant_type"] |
| 112 | + assert data["atype"] == i8i8_data["activ_quant_type"] |
| 113 | + assert data["smoothquant"] == i8i8_data["smoothquant"] |
| 114 | + assert all( |
| 115 | + item in i8i8_data for item in ["x", "w_int", "bias", "qdata", "reference_out"] |
107 | 116 | ) |
108 | 117 |
|
109 | | - return (x, w_int, b, qdata, data["wtype"], data["atype"], data["smoothquant"]) |
110 | | - |
111 | | - |
112 | | -def create_qdata( |
113 | | - wtype: str, |
114 | | - atype: str, |
115 | | - in_feat: int, |
116 | | - out_feat: int, |
117 | | - smoothquant: bool, |
118 | | - dtype: torch.dtype, |
119 | | -) -> torch.Tensor: |
120 | | - """Generate dummy qdata tensor based on the provided quantization configuration""" |
121 | | - |
122 | | - qdata_len = 2 if wtype == "per_tensor" else 2 * out_feat # weight clips |
123 | | - qdata_len += 2 # activation clips |
124 | | - qdata_len += out_feat if atype == "per_tensor_asymm" else 1 # zero shift |
125 | | - qdata_len += in_feat if smoothquant else 1 # smoothquant scales |
126 | | - |
127 | | - # TODO: improve dummy generation |
128 | | - qdata = torch.ones(qdata_len, dtype=dtype) |
129 | | - qdata[1] = -qdata[0] # !!! temporary solution to enforce clip symmetry |
130 | | - qdata[3] = -qdata[2] |
131 | | - return qdata |
| 118 | + return ( |
| 119 | + i8i8_data["x"], |
| 120 | + i8i8_data["w_int"], |
| 121 | + i8i8_data["bias"], |
| 122 | + i8i8_data["qdata"], |
| 123 | + i8i8_data["weight_quant_type"], |
| 124 | + i8i8_data["activ_quant_type"], |
| 125 | + i8i8_data["smoothquant"], |
| 126 | + i8i8_data["reference_out"], |
| 127 | + ) |
0 commit comments