Skip to content

Commit ee6ea6d

Browse files
Add QInt8, QUInt8, QInt32 quantized scalar types with full quantization support (#1531)
* Add QInt8, QUInt8, QInt32 quantized scalar types with full quantization support - Uncomment and fix ScalarType enum entries (QInt8=12, QUInt8=13, QInt32=14) - Fix incorrect QUInt32 name to QInt32 to match PyTorch - Add torch.qint8, torch.quint8, torch.qint32 dtype aliases - Add torch.is_quantized() and Tensor.is_quantized() methods - Add IsQuantized() extension method on ScalarType Native quantization bindings: - Add THSTensor_quantize_per_tensor, THSTensor_quantize_per_channel (C++ and P/Invoke) - Add THSTensor_dequantize (C++ and P/Invoke) - Add THSTensor_q_scale, THSTensor_q_zero_point (C++ and P/Invoke) - Add THSTensor_int_repr (C++ and P/Invoke) - Add THSTensor_q_per_channel_scales, THSTensor_q_per_channel_zero_points, THSTensor_q_per_channel_axis (C++ and P/Invoke) Managed API: - Add torch.quantize_per_tensor() and torch.quantize_per_channel() static methods - Add torch.dequantize() static method - Add Tensor.dequantize(), Tensor.q_scale(), Tensor.q_zero_point() instance methods - Add Tensor.int_repr() instance method - Add Tensor.q_per_channel_scales(), Tensor.q_per_channel_zero_points(), Tensor.q_per_channel_axis() instance methods Unit tests for all new functionality (13 tests) * Address PR review: add quantized type ElementSize + fix tensor dispose leak - Add QInt8 (1), QUInt8 (1), QInt32 (4) to ElementSize() so it no longer throws NotImplementedException for quantized types - Split chained tensor().reshape() in QuantizePerChannel test to properly dispose the intermediate 1D tensor via 'using var' Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 38988a2 commit ee6ea6d

7 files changed

Lines changed: 472 additions & 3 deletions

File tree

src/Native/LibTorchSharp/THSTensor.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,3 +2306,48 @@ Tensor THSTensor_unflatten_names(Tensor tensor, const char** names, const int64_
23062306

23072307
return nullptr;
23082308
}
2309+
2310+
Tensor THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type)
2311+
{
2312+
CATCH_TENSOR(torch::quantize_per_tensor(*tensor, scale, zero_point, at::ScalarType(scalar_type)));
2313+
}
2314+
2315+
Tensor THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type)
2316+
{
2317+
CATCH_TENSOR(torch::quantize_per_channel(*tensor, *scales, *zero_points, axis, at::ScalarType(scalar_type)));
2318+
}
2319+
2320+
Tensor THSTensor_dequantize(const Tensor tensor)
2321+
{
2322+
CATCH_TENSOR(tensor->dequantize());
2323+
}
2324+
2325+
double THSTensor_q_scale(const Tensor tensor)
2326+
{
2327+
CATCH_RETURN(double, 0.0, tensor->q_scale());
2328+
}
2329+
2330+
int64_t THSTensor_q_zero_point(const Tensor tensor)
2331+
{
2332+
CATCH_RETURN(int64_t, 0, tensor->q_zero_point());
2333+
}
2334+
2335+
Tensor THSTensor_int_repr(const Tensor tensor)
2336+
{
2337+
CATCH_TENSOR(tensor->int_repr());
2338+
}
2339+
2340+
Tensor THSTensor_q_per_channel_scales(const Tensor tensor)
2341+
{
2342+
CATCH_TENSOR(tensor->q_per_channel_scales());
2343+
}
2344+
2345+
Tensor THSTensor_q_per_channel_zero_points(const Tensor tensor)
2346+
{
2347+
CATCH_TENSOR(tensor->q_per_channel_zero_points());
2348+
}
2349+
2350+
int64_t THSTensor_q_per_channel_axis(const Tensor tensor)
2351+
{
2352+
CATCH_RETURN(int64_t, 0, tensor->q_per_channel_axis());
2353+
}

src/Native/LibTorchSharp/THSTensor.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,3 +1808,15 @@ EXPORT_API(Tensor) THSTensor_kaiser_window(const int64_t len, bool periodic, dou
18081808

18091809
EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, int64_t onesided, bool return_complex);
18101810
EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, int64_t onesided, int64_t length, bool return_complex);
1811+
1812+
// Quantization Ops
1813+
1814+
EXPORT_API(Tensor) THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type);
1815+
EXPORT_API(Tensor) THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type);
1816+
EXPORT_API(Tensor) THSTensor_dequantize(const Tensor tensor);
1817+
EXPORT_API(double) THSTensor_q_scale(const Tensor tensor);
1818+
EXPORT_API(int64_t) THSTensor_q_zero_point(const Tensor tensor);
1819+
EXPORT_API(Tensor) THSTensor_int_repr(const Tensor tensor);
1820+
EXPORT_API(Tensor) THSTensor_q_per_channel_scales(const Tensor tensor);
1821+
EXPORT_API(Tensor) THSTensor_q_per_channel_zero_points(const Tensor tensor);
1822+
EXPORT_API(int64_t) THSTensor_q_per_channel_axis(const Tensor tensor);

src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2206,6 +2206,33 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
22062206
internal static extern IntPtr THSTensor_histogram_out_t(IntPtr input, IntPtr bins, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
22072207
[DllImport("LibTorchSharp")]
22082208
internal static extern IntPtr THSTensor_histogram_out_i(IntPtr input, long bins, IntPtr range, int length, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
2209+
2210+
// Quantization entry points; signatures mirror the native declarations in THSTensor.h.
// IntPtr parameters/returns are opaque native tensor handles.

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_quantize_per_tensor(IntPtr tensor, double scale, long zero_point, sbyte scalar_type);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_quantize_per_channel(IntPtr tensor, IntPtr scales, IntPtr zero_points, long axis, sbyte scalar_type);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_dequantize(IntPtr tensor);

[DllImport("LibTorchSharp")]
internal static extern double THSTensor_q_scale(IntPtr tensor);

[DllImport("LibTorchSharp")]
internal static extern long THSTensor_q_zero_point(IntPtr tensor);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_int_repr(IntPtr tensor);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_q_per_channel_scales(IntPtr tensor);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_q_per_channel_zero_points(IntPtr tensor);

[DllImport("LibTorchSharp")]
internal static extern long THSTensor_q_per_channel_axis(IntPtr tensor);
22092236
}
22102237
#pragma warning restore CA2101
22112238
}

src/TorchSharp/Tensor/Tensor.cs

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,95 @@ internal IntPtr MoveHandle()
271271
/// </summary>
272272
public bool is_complex() => torch.is_complex(dtype);
273273

274+
/// <summary>
/// Returns True if the data type of input is a quantized data type i.e., one of torch.qint8, torch.quint8, and torch.qint32.
/// </summary>
public bool is_quantized() => torch.is_quantized(dtype);

/// <summary>
/// Given a quantized Tensor, returns a dequantized (float) Tensor.
/// </summary>
public Tensor dequantize()
{
    var result = NativeMethods.THSTensor_dequantize(Handle);
    // A zero handle signals a native-side failure; surface it as a managed exception.
    if (result == IntPtr.Zero) { CheckForErrors(); }
    return new Tensor(result);
}

/// <summary>
/// Given a quantized Tensor, returns the scale of the quantization as a double.
/// </summary>
public double q_scale()
{
    var scale = NativeMethods.THSTensor_q_scale(Handle);
    // The native call returns a sentinel on error, so always check afterwards.
    CheckForErrors();
    return scale;
}

/// <summary>
/// Given a quantized Tensor, returns the zero_point of the quantization as a long.
/// </summary>
public long q_zero_point()
{
    var zeroPoint = NativeMethods.THSTensor_q_zero_point(Handle);
    CheckForErrors();
    return zeroPoint;
}

/// <summary>
/// Given a quantized Tensor, returns a Tensor of the underlying integer representation.
/// </summary>
public Tensor int_repr()
{
    var result = NativeMethods.THSTensor_int_repr(Handle);
    if (result == IntPtr.Zero) { CheckForErrors(); }
    return new Tensor(result);
}

/// <summary>
/// Given a quantized Tensor quantized per channel, returns a Tensor of the scales of the quantization for each channel.
/// </summary>
public Tensor q_per_channel_scales()
{
    var result = NativeMethods.THSTensor_q_per_channel_scales(Handle);
    if (result == IntPtr.Zero) { CheckForErrors(); }
    return new Tensor(result);
}

/// <summary>
/// Given a quantized Tensor quantized per channel, returns a Tensor of the zero points of the quantization for each channel.
/// </summary>
public Tensor q_per_channel_zero_points()
{
    var result = NativeMethods.THSTensor_q_per_channel_zero_points(Handle);
    if (result == IntPtr.Zero) { CheckForErrors(); }
    return new Tensor(result);
}

/// <summary>
/// Given a quantized Tensor quantized per channel, returns the axis along which per channel quantization is applied.
/// </summary>
public long q_per_channel_axis()
{
    var axis = NativeMethods.THSTensor_q_per_channel_axis(Handle);
    CheckForErrors();
    return axis;
}

// Internal worker for torch.quantize_per_tensor(); dtype validation happens in the public API.
internal Tensor _quantize_per_tensor(double scale, long zero_point, ScalarType dtype)
{
    var result = NativeMethods.THSTensor_quantize_per_tensor(Handle, scale, zero_point, (sbyte)dtype);
    if (result == IntPtr.Zero) { CheckForErrors(); }
    return new Tensor(result);
}

// Internal worker for torch.quantize_per_channel(); dtype validation happens in the public API.
internal Tensor _quantize_per_channel(Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
{
    var result = NativeMethods.THSTensor_quantize_per_channel(Handle, scales.Handle, zero_points.Handle, axis, (sbyte)dtype);
    if (result == IntPtr.Zero) { CheckForErrors(); }
    return new Tensor(result);
}
362+
274363
/// <summary>
275364
/// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
276365
/// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).
@@ -7359,9 +7448,9 @@ public enum ScalarType : sbyte
73597448
ComplexFloat32 = 9,
73607449
ComplexFloat64 = 10,
73617450
Bool = 11,
7362-
//QInt8 = 12,
7363-
//QUInt8 = 13,
7364-
//QUInt32 = 14,
7451+
QInt8 = 12,
7452+
QUInt8 = 13,
7453+
QInt32 = 14,
73657454
BFloat16 = 15
73667455
}
73677456

@@ -7493,6 +7582,18 @@ public static bool is_complex(ScalarType type)
74937582
}
74947583
}
74957584

7585+
/// <summary>
/// Returns true if the given element type is one of the quantized types
/// (QInt8, QUInt8, or QInt32), mirroring PyTorch's torch.is_quantized.
/// </summary>
public static bool is_quantized(ScalarType type)
{
    return type == ScalarType.QInt8
        || type == ScalarType.QUInt8
        || type == ScalarType.QInt32;
}
7596+
74967597
public static long max_int_value(ScalarType type)
74977598
{
74987599
switch (type) {
@@ -7543,6 +7644,10 @@ public static long max_int_value(ScalarType type)
75437644
public static ScalarType cfloat = ScalarType.ComplexFloat32;
75447645
public static ScalarType cdouble = ScalarType.ComplexFloat64;
75457646

7647+
// Quantized dtype aliases, mirroring PyTorch's torch.qint8 / torch.quint8 / torch.qint32.
// NOTE(review): mutable public statics for consistency with the surrounding dtype
// aliases in this class; treat them as constants.
public static ScalarType qint8 = ScalarType.QInt8;
public static ScalarType quint8 = ScalarType.QUInt8;
public static ScalarType qint32 = ScalarType.QInt32;
7650+
75467651
/// <summary>
75477652
/// Creates a new dispose scope for the current thread. Any tensor created within the dispose scope will
75487653
/// be automatically disposed once the dispose scope is disposed.

src/TorchSharp/Tensor/TensorExtensionMethods.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,9 @@ public static int ElementSize(this ScalarType type)
303303
ScalarType.ComplexFloat32 => 8,
304304
ScalarType.ComplexFloat64 => 16,
305305
ScalarType.Bool => 1,
306+
ScalarType.QInt8 => 1,
307+
ScalarType.QUInt8 => 1,
308+
ScalarType.QInt32 => 4,
306309
ScalarType.BFloat16 => 2,
307310
_ => throw new NotImplementedException()
308311
};
@@ -368,6 +371,23 @@ internal static bool IsComplex(this ScalarType type)
368371
}
369372
}
370373

374+
/// <summary>
/// Indicates whether a given element type is quantized.
/// </summary>
/// <param name="type">The input type.</param>
/// <returns>True for QInt8, QUInt8, and QInt32; false for every other type.</returns>
internal static bool IsQuantized(this ScalarType type)
{
    return type == ScalarType.QInt8
        || type == ScalarType.QUInt8
        || type == ScalarType.QInt32;
}
390+
371391
/// <summary>
372392
/// Save the tensor in a .NET-specific format.
373393
/// </summary>

src/TorchSharp/Tensor/torch.PointwiseOps.cs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,47 @@ public static Tensor fake_quantize_per_channel_affine(Tensor input, Tensor scale
779779
public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale, Tensor zero_point, long quant_min, long quant_max)
780780
=> throw new NotImplementedException();
781781

782+
// https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor
/// <summary>
/// Converts a float tensor to a quantized tensor with given scale and zero point.
/// </summary>
/// <param name="input">Float tensor to quantize</param>
/// <param name="scale">Scale to apply in quantization formula</param>
/// <param name="zero_point">Offset in integer value that maps to float zero</param>
/// <param name="dtype">The desired data type of returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
/// <returns>A newly quantized tensor</returns>
/// <exception cref="ArgumentException">Thrown when dtype is not a quantized type.</exception>
public static Tensor quantize_per_tensor(Tensor input, double scale, long zero_point, ScalarType dtype)
{
    // Validate up front so the native layer never sees a non-quantized target type.
    if (!is_quantized(dtype))
        throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
    return input._quantize_per_tensor(scale, zero_point, dtype);
}

// https://pytorch.org/docs/stable/generated/torch.quantize_per_channel
/// <summary>
/// Converts a float tensor to a per-channel quantized tensor with given scales and zero points.
/// </summary>
/// <param name="input">Float tensor to quantize</param>
/// <param name="scales">Float 1D tensor of scales to use, size should match input.size(axis)</param>
/// <param name="zero_points">Integer 1D tensor of offsets to use, size should match input.size(axis)</param>
/// <param name="axis">Dimension on which to apply per-channel quantization</param>
/// <param name="dtype">The desired data type of returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
/// <returns>A newly quantized tensor</returns>
/// <exception cref="ArgumentException">Thrown when dtype is not a quantized type.</exception>
public static Tensor quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
{
    // Validate up front so the native layer never sees a non-quantized target type.
    if (!is_quantized(dtype))
        throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
    return input._quantize_per_channel(scales, zero_points, axis, dtype);
}

// https://pytorch.org/docs/stable/generated/torch.dequantize
/// <summary>
/// Returns an fp32 Tensor by dequantizing a quantized Tensor.
/// </summary>
/// <param name="input">A quantized tensor</param>
/// <returns>A dequantized (float) tensor</returns>
public static Tensor dequantize(Tensor input) => input.dequantize();
822+
782823
// https://pytorch.org/docs/stable/generated/torch.fix
783824
/// <summary>
784825
/// Returns a new tensor with the truncated integer values of the elements of input.

0 commit comments

Comments
 (0)