@@ -271,6 +271,95 @@ internal IntPtr MoveHandle()
271271 /// </summary>
public bool is_complex()
{
    // Complex-ness is a property of the element type, not the tensor data.
    return torch.is_complex(dtype);
}
273273
/// <summary>
/// Returns True if the data type of input is a quantized data type i.e., one of torch.qint8, torch.quint8, and torch.qint32.
/// </summary>
public bool is_quantized()
{
    // Delegate to the static dtype-level predicate using this tensor's element type.
    return torch.is_quantized(dtype);
}
278+
/// <summary>
/// Given a quantized Tensor, returns a dequantized (float) Tensor.
/// </summary>
public Tensor dequantize()
{
    var ptr = NativeMethods.THSTensor_dequantize(Handle);
    // A zero handle signals a native-side failure; surface it as a managed error.
    if (ptr == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(ptr);
}
288+
/// <summary>
/// Given a quantized Tensor, returns the scale of the quantization as a double.
/// </summary>
public double q_scale()
{
    double scale = NativeMethods.THSTensor_q_scale(Handle);
    // Scalar return carries no error sentinel, so always poll for native errors.
    CheckForErrors();
    return scale;
}
298+
/// <summary>
/// Given a quantized Tensor, returns the zero_point of the quantization as a long.
/// </summary>
public long q_zero_point()
{
    long zeroPoint = NativeMethods.THSTensor_q_zero_point(Handle);
    // Scalar return carries no error sentinel, so always poll for native errors.
    CheckForErrors();
    return zeroPoint;
}
308+
/// <summary>
/// Given a quantized Tensor, returns a Tensor of the underlying integer representation.
/// </summary>
public Tensor int_repr()
{
    var ptr = NativeMethods.THSTensor_int_repr(Handle);
    // A zero handle signals a native-side failure; surface it as a managed error.
    if (ptr == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(ptr);
}
318+
/// <summary>
/// Given a quantized Tensor quantized per channel, returns a Tensor of the scales of the quantization for each channel.
/// </summary>
public Tensor q_per_channel_scales()
{
    var ptr = NativeMethods.THSTensor_q_per_channel_scales(Handle);
    // A zero handle signals a native-side failure; surface it as a managed error.
    if (ptr == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(ptr);
}
328+
/// <summary>
/// Given a quantized Tensor quantized per channel, returns a Tensor of the zero points of the quantization for each channel.
/// </summary>
public Tensor q_per_channel_zero_points()
{
    var ptr = NativeMethods.THSTensor_q_per_channel_zero_points(Handle);
    // A zero handle signals a native-side failure; surface it as a managed error.
    if (ptr == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(ptr);
}
338+
/// <summary>
/// Given a quantized Tensor quantized per channel, returns the axis along which per channel quantization is applied.
/// </summary>
public long q_per_channel_axis()
{
    long axis = NativeMethods.THSTensor_q_per_channel_axis(Handle);
    // Scalar return carries no error sentinel, so always poll for native errors.
    CheckForErrors();
    return axis;
}
348+
/// <summary>
/// Native-call core for per-tensor quantization: quantizes this tensor with the
/// given scale and zero point into the requested quantized dtype.
/// </summary>
/// <param name="scale">Quantization scale applied to every element.</param>
/// <param name="zero_point">Quantization zero point applied to every element.</param>
/// <param name="dtype">Target quantized scalar type (passed to native code as its sbyte value).</param>
internal Tensor _quantize_per_tensor(double scale, long zero_point, ScalarType dtype)
{
    var ptr = NativeMethods.THSTensor_quantize_per_tensor(Handle, scale, zero_point, (sbyte)dtype);
    // A zero handle signals a native-side failure; surface it as a managed error.
    if (ptr == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(ptr);
}
355+
/// <summary>
/// Native-call core for per-channel quantization: quantizes this tensor using
/// per-channel scales and zero points along the given axis.
/// </summary>
/// <param name="scales">Tensor of per-channel quantization scales.</param>
/// <param name="zero_points">Tensor of per-channel quantization zero points.</param>
/// <param name="axis">Dimension along which per-channel quantization is applied.</param>
/// <param name="dtype">Target quantized scalar type (passed to native code as its sbyte value).</param>
internal Tensor _quantize_per_channel(Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
{
    var ptr = NativeMethods.THSTensor_quantize_per_channel(Handle, scales.Handle, zero_points.Handle, axis, (sbyte)dtype);
    // A zero handle signals a native-side failure; surface it as a managed error.
    if (ptr == IntPtr.Zero)
        CheckForErrors();
    return new Tensor(ptr);
}
362+
274363 /// <summary>
275364 /// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
276365 /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).
@@ -7359,9 +7448,9 @@ public enum ScalarType : sbyte
73597448 ComplexFloat32 = 9 ,
73607449 ComplexFloat64 = 10 ,
73617450 Bool = 11 ,
7362- // QInt8 = 12,
7363- // QUInt8 = 13,
7364- //QUInt32 = 14,
7451+ QInt8 = 12 ,
7452+ QUInt8 = 13 ,
7453+ QInt32 = 14 ,
73657454 BFloat16 = 15
73667455 }
73677456
@@ -7493,6 +7582,18 @@ public static bool is_complex(ScalarType type)
74937582 }
74947583 }
74957584
/// <summary>
/// Returns true when the given scalar type is a quantized data type,
/// i.e. one of QInt8, QUInt8, or QInt32; false for all other types.
/// </summary>
public static bool is_quantized(ScalarType type)
{
    return type == ScalarType.QInt8
        || type == ScalarType.QUInt8
        || type == ScalarType.QInt32;
}
7596+
74967597 public static long max_int_value ( ScalarType type )
74977598 {
74987599 switch ( type ) {
@@ -7543,6 +7644,10 @@ public static long max_int_value(ScalarType type)
75437644 public static ScalarType cfloat = ScalarType . ComplexFloat32 ;
75447645 public static ScalarType cdouble = ScalarType . ComplexFloat64 ;
75457646
// Lower-case aliases for the quantized scalar types, mirroring PyTorch's
// torch.qint8 / torch.quint8 / torch.qint32 dtype names (same pattern as
// the cfloat/cdouble aliases above).
public static ScalarType qint8 = ScalarType . QInt8 ;
public static ScalarType quint8 = ScalarType . QUInt8 ;
public static ScalarType qint32 = ScalarType . QInt32 ;
7650+
75467651 /// <summary>
75477652 /// Creates a new dispose scope for the current thread. Any tensor created within the dispose scope will
75487653 /// be automatically disposed once the dispose scope is disposed.
0 commit comments