@@ -109,6 +109,18 @@ def dequantize_gemm(qweight, qzeros, scales, bits, group_size):
109109
110110
111111def pack_weight_to_int8 (weight ):
112+ """Pack two INT4 values into one INT8 byte (CPU, numpy-based).
113+
114+ Original implementation using Python loops for packing.
115+ Kept for debugging and fallback.
116+ For GPU-accelerated packing, use pack_weight_to_int8_gpu.
117+
118+ Args:
119+ weight: Tensor of shape (out_features, in_features) with values in [-8, 7].
120+
121+ Returns:
122+ Packed INT8 tensor of shape (out_features, in_features // 2) on CPU.
123+ """
112124 weight = weight .t ().contiguous ().cpu ()
113125 weight = weight .to (torch .float32 ).numpy ().astype (np .int8 )
114126
@@ -124,3 +136,34 @@ def pack_weight_to_int8(weight):
124136 packed_weight = packed_weight .astype (np .int8 )
125137 packed_weight = torch .from_numpy (packed_weight ).t ().contiguous ()
126138 return packed_weight
139+
140+
141+ def pack_weight_to_int8_gpu (weight ):
142+ """Pack two INT4 values into one INT8 byte using pure PyTorch (GPU-accelerated).
143+
144+ Supports both CPU and GPU tensors — no numpy dependency, so packing
145+ can be done directly on GPU without device transfer overhead.
146+
147+ Input layout (after transpose): rows are paired (row 0,1 -> packed row 0, etc.)
148+ Low nibble = even row, high nibble = odd row.
149+
150+ Args:
151+ weight: Tensor of shape (out_features, in_features) with values in [-8, 7].
152+ Can be on any device (CPU or CUDA).
153+
154+ Returns:
155+ Packed INT8 tensor of shape (out_features, in_features // 2),
156+ on the same device as input.
157+ """
158+ # Transpose to (in_features, out_features) for row-pair packing
159+ weight = weight .t ().contiguous ().to (torch .int8 )
160+
161+ # Vectorized packing: pair adjacent rows and combine low/high nibbles
162+ # Even rows -> low nibble, odd rows -> high nibble
163+ even_rows = weight [0 ::2 ] # shape: (rows//2, cols)
164+ odd_rows = weight [1 ::2 ] # shape: (rows//2, cols)
165+ packed_weight = (even_rows & 0x0F ) | ((odd_rows & 0x0F ) << 4 )
166+
167+ # Transpose back to (out_features, in_features // 2)
168+ packed_weight = packed_weight .t ().contiguous ()
169+ return packed_weight
0 commit comments