-
Notifications
You must be signed in to change notification settings - Fork 972
Expand file tree
/
Copy pathXNNCompiler.cpp
More file actions
2038 lines (1839 loc) · 70 KB
/
XNNCompiler.cpp
File metadata and controls
2038 lines (1839 loc) · 70 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/xnnpack/runtime/XNNCompiler.h>
#include <executorch/backends/xnnpack/runtime/XNNHeader.h>
#include <executorch/backends/xnnpack/serialization/schema_generated.h>
#include <executorch/extension/threadpool/threadpool.h>
#include <executorch/runtime/executor/pte_data_map.h>
#include <xnnpack.h>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wglobal-constructors"
namespace executorch {
namespace backends {
namespace xnnpack {
namespace delegate {
using executorch::ET_RUNTIME_NAMESPACE::NamedDataMap;
using executorch::runtime::Error;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::Result;
/*
 * Provide compile-time allocation.
 *
 * Owns scratch buffers that are needed while the subgraph is being defined
 * and released together when the allocator is destroyed.
 */
class CompileAllocator {
 public:
  /*
   * Allocate `size` bytes of (uninitialized) memory which will be
   * automatically freed at the end of the compilation process, i.e. when this
   * allocator is destroyed.
   */
  void* allocateTemporary(size_t size) {
    // Take ownership before touching the vector: if growing temporaries_
    // throws, the unique_ptr still releases the block instead of leaking it.
    // (make_unique<uint8_t[]> is avoided because it would zero the memory.)
    std::unique_ptr<uint8_t[]> mem(new uint8_t[size]);
    uint8_t* raw = mem.get();
    temporaries_.push_back(std::move(mem));
    return raw;
  }

 private:
  std::vector<std::unique_ptr<uint8_t[]>> temporaries_;
};
// Flatbuffer types
// Read-only pointers into the deserialized XNNPACK flatbuffer graph; the
// pointees are owned by the flatbuffer buffer, not by these aliases.
using ValuePtr = const fb_xnnpack::XValue*;
using NodePtr = const fb_xnnpack::XNode*;
using GraphPtr = const fb_xnnpack::XNNGraph*;
using ConstantDataOffsetPtr = const fb_xnnpack::ConstantDataOffset*;
// Serialized (flatbuffer) datatype enum; convert with getDataType() before
// handing it to XNNPACK.
using DataType = fb_xnnpack::XNNDatatype;
// Type for define node function. This is the function signature
// for any function that takes in a flatbuffer node and defines it
// into our xnn_subgraph.
// Parameters: the subgraph being built, the map from serialized value ids to
// XNNPACK value ids, the node to define, and the whole serialized graph (for
// nodes that need to inspect other values/nodes).
using DefineNodeFunc = Error (*)(
xnn_subgraph_t,
const std::unordered_map<uint32_t, uint32_t>&,
NodePtr,
const fb_xnnpack::XNNGraph*) noexcept;
/*
Convert a tensor from fp32 to bf16 using round-to-nearest-even.

bf16 keeps the upper 16 bits of an IEEE-754 binary32 value. Rounding is done
on the bit pattern: a bias of 0x7FFF plus the lowest kept bit is added before
truncation, so values round to the nearest bf16 and exact ties round to an
even mantissa. NaNs are kept as (quiet) NaNs; plain truncation could otherwise
turn a NaN whose payload lives only in the low 16 bits into an infinity.

@param f32_data      Source values (numel elements).
@param bf16_data_out Destination bf16 bit patterns (numel elements).
@param numel         Number of elements to convert.
*/
void convertF32TensorToBF16(
    const float* f32_data,
    uint16_t* bf16_data_out,
    size_t numel) {
  // size_t index: an `unsigned int` counter would wrap for numel > UINT_MAX.
  for (size_t i = 0; i < numel; ++i) {
    uint32_t f32_bits;
    memcpy(&f32_bits, &f32_data[i], sizeof(float));
    if ((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) {
      // NaN: keep sign and high payload bits, force the quiet bit.
      bf16_data_out[i] = static_cast<uint16_t>((f32_bits >> 16) | 0x0040u);
      continue;
    }
    // Cannot overflow: the largest non-NaN pattern is 0xFF800000 (-inf).
    const uint32_t lowest_kept_bit = (f32_bits >> 16) & 1u;
    const uint32_t rounding_bias = 0x7FFFu + lowest_kept_bit;
    bf16_data_out[i] =
        static_cast<uint16_t>((f32_bits + rounding_bias) >> 16);
  }
}
/*
Gets the output min and output max for a given node operator
*/
std::pair<float, float> getOutputMinMax(const NodePtr node) noexcept {
float output_min = -std::numeric_limits<float>::infinity();
float output_max = std::numeric_limits<float>::infinity();
auto output_min_max = node->output_min_max();
if (output_min_max != nullptr) {
output_min = output_min_max->output_min();
output_max = output_min_max->output_max();
}
return {output_min, output_max};
}
/*
Maps a serialized (flatbuffer) datatype onto the matching XNNPACK
xnn_datatype. Any value without a counterpart yields xnn_datatype_invalid.
*/
xnn_datatype getDataType(const DataType& data_type) {
  xnn_datatype converted = xnn_datatype::xnn_datatype_invalid;
  switch (data_type) {
    case DataType::xnn_datatype_fp32:
      converted = xnn_datatype::xnn_datatype_fp32;
      break;
    case DataType::xnn_datatype_fp16:
      converted = xnn_datatype::xnn_datatype_fp16;
      break;
    case DataType::xnn_datatype_qint8:
      converted = xnn_datatype::xnn_datatype_qint8;
      break;
    case DataType::xnn_datatype_quint8:
      converted = xnn_datatype::xnn_datatype_quint8;
      break;
    case DataType::xnn_datatype_qint32:
      converted = xnn_datatype::xnn_datatype_qint32;
      break;
    case DataType::xnn_datatype_qcint8:
      converted = xnn_datatype::xnn_datatype_qcint8;
      break;
    case DataType::xnn_datatype_qcint32:
      converted = xnn_datatype::xnn_datatype_qcint32;
      break;
    case DataType::xnn_datatype_qcint4:
      converted = xnn_datatype::xnn_datatype_qcint4;
      break;
    case DataType::xnn_datatype_qdint8:
      converted = xnn_datatype::xnn_datatype_qdint8;
      break;
    case DataType::xnn_datatype_qbint4:
      converted = xnn_datatype::xnn_datatype_qbint4;
      break;
    case DataType::xnn_datatype_qpint8:
      converted = xnn_datatype::xnn_datatype_qpint8;
      break;
    case DataType::xnn_datatype_int32:
      converted = xnn_datatype::xnn_datatype_int32;
      break;
    case DataType::xnn_datatype_pfp32:
      converted = xnn_datatype::xnn_datatype_pfp32;
      break;
    case DataType::xnn_datatype_bf16:
      converted = xnn_datatype::xnn_datatype_bf16;
      break;
    default:
      break;
  }
  return converted;
}
/*
True for the XNNPACK datatypes this file treats as quantized: qint8, quint8,
qint32, qcint8, qcint32, qcint4 and qdint8. Every other datatype (including
qbint4 and qpint8) returns false.
*/
bool isQuantizedDataType(const xnn_datatype data_type) {
  return data_type == xnn_datatype::xnn_datatype_qint8 ||
      data_type == xnn_datatype::xnn_datatype_quint8 ||
      data_type == xnn_datatype::xnn_datatype_qint32 ||
      data_type == xnn_datatype::xnn_datatype_qcint8 ||
      data_type == xnn_datatype::xnn_datatype_qcint32 ||
      data_type == xnn_datatype::xnn_datatype_qcint4 ||
      data_type == xnn_datatype::xnn_datatype_qdint8;
}
/**
Copies a flatbuffer vector of uint32_t dims into a std::vector<T>
(T defaults to size_t). XNNPACK takes dims as size_t* while shapes are
serialized as 32-bit integers, so each element is static_cast to T.
*/
template <typename T = size_t>
std::vector<T> flatbufferDimsToVector(
    const flatbuffers::Vector<uint32_t>* fb_dims) {
  using Index = decltype(fb_dims->size());
  const Index count = fb_dims->size();
  std::vector<T> converted;
  converted.reserve(count);
  for (Index i = 0; i < count; ++i) {
    converted.push_back(static_cast<T>(fb_dims->Get(i)));
  }
  return converted;
}
/**
Gets the constant data pointer associated with the given tensor value.
Obtaining the constant data pointer can either be from within the flatbuffer
payload (deprecated) or via offsets to the constant_data_ptr. If no constant
data associated with the tensor value, then returns nullptr.
*/
const uint8_t* getConstantDataPtr(
uint32_t buffer_idx,
GraphPtr flatbuffer_graph,
const uint8_t* constant_data_ptr,
const NamedDataMap* named_data_map,
std::vector<FreeableBuffer>& freeable_buffers,
XNNWeightsCache* weights_cache) {
if (buffer_idx) {
if (!constant_data_ptr) {
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
// window
const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
return constant_buffer[buffer_idx]->storage()->data();
} else {
ConstantDataOffsetPtr constant_data_offset =
flatbuffer_graph->constant_data()->Get(buffer_idx);
uint64_t offset = constant_data_offset->offset();
bool has_named_key = flatbuffers::IsFieldPresent(
constant_data_offset, fb_xnnpack::ConstantDataOffset::VT_NAMED_KEY);
// If there is no tensor name
if (!has_named_key) {
return constant_data_ptr + offset;
} else {
const std::string& data_name = constant_data_offset->named_key()->str();
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
Result<const uint8_t*> data_ptr =
weights_cache->load_unpacked_data(data_name);
if (!data_ptr.ok()) {
ET_LOG(Error, "Failed to load weights from cache");
return nullptr;
}
return data_ptr.get();
#else
Result<FreeableBuffer> buffer =
named_data_map->get_data(data_name.c_str());
if (!buffer.ok()) {
ET_LOG(
Error,
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(buffer.error()));
return nullptr;
}
const uint8_t* data_ptr =
static_cast<const uint8_t*>(buffer.get().data());
freeable_buffers.push_back(std::move(buffer.get()));
return data_ptr;
#endif
}
}
}
return nullptr;
}
/*
Tensor-value convenience overload: looks up the tensor's constant_buffer_idx
and forwards everything else to the index-based getConstantDataPtr. Yields
nullptr for non-constant tensors (index 0) or when loading fails.
*/
const uint8_t* getConstantDataPtr(
    const fb_xnnpack::XNNTensorValue* tensor_value,
    GraphPtr flatbuffer_graph,
    const uint8_t* constant_data_ptr,
    const NamedDataMap* named_data_map,
    std::vector<FreeableBuffer>& freeable_buffers,
    XNNWeightsCache* weights_cache) {
  const auto constant_idx = tensor_value->constant_buffer_idx();
  const uint8_t* resolved = getConstantDataPtr(
      constant_idx,
      flatbuffer_graph,
      constant_data_ptr,
      named_data_map,
      freeable_buffers,
      weights_cache);
  return resolved;
}
/**
Define serialized tensor value into the subgraph, while also keeping track of
the remapped ids from the serialized id to the newly generated id.

@param subgraph_ptr      Subgraph the tensor is defined into.
@param remapped_ids      Updated with serialized id_out -> XNNPACK value id.
@param value             Serialized XValue; must hold an XNNTensorValue or an
                         XNNQuantizedTensorValue.
@param flatbuffer_graph  Graph used to resolve constant data.
@param constant_data_ptr Base of the constant data segment, or nullptr to use
                         the legacy in-flatbuffer constant_buffer.
@param input_ids         Receives the external id when the tensor is flagged
                         XNN_VALUE_FLAG_EXTERNAL_INPUT.
@param output_ids        Receives the external id when the tensor is flagged
                         XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
@param allocator         Owns scratch buffers (fp32 -> bf16 block scales) that
                         must live until compilation finishes.
@param named_data_map    Source for constants referenced by a named key.
@param freeable_buffers  Keeps buffers loaded from named_data_map alive.
@param weights_cache     Used for named constants when
                         ENABLE_XNNPACK_WEIGHTS_CACHE is defined.
@return Error::Ok on success; NotImplemented for unsupported value or
        quantization kinds; Internal for malformed data, failed constant loads
        or XNNPACK definition failures.
*/
Error defineTensor(
    xnn_subgraph_t subgraph_ptr,
    std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    ValuePtr value,
    GraphPtr flatbuffer_graph,
    const uint8_t* constant_data_ptr,
    std::vector<uint32_t>& input_ids,
    std::vector<uint32_t>& output_ids,
    CompileAllocator& allocator,
    const NamedDataMap* named_data_map,
    std::vector<FreeableBuffer>& freeable_buffers,
    XNNWeightsCache* weights_cache) {
  const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
  const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;

  switch (value->xvalue_union_type()) {
    case fb_xnnpack::XValueUnion::XNNTensorValue: {
      tensor_value = value->xvalue_union_as_XNNTensorValue();
      break;
    }
    case fb_xnnpack::XValueUnion::XNNQuantizedTensorValue: {
      qtensor_value = value->xvalue_union_as_XNNQuantizedTensorValue();
      tensor_value = qtensor_value->tensor_value();
      break;
    }
    default: {
      ET_CHECK_OR_RETURN_ERROR(
          false,
          NotImplemented,
          "Unhandled value type: %s",
          fb_xnnpack::EnumNameXValueUnion(value->xvalue_union_type()));
    }
  }

  ET_CHECK_OR_RETURN_ERROR(
      tensor_value != nullptr,
      Internal,
      "Deserialized Tensor is Null, this should never happen");

  // Get tensor dims, here we need to use a vector in order
  // to properly convert the uint32_t* to size_t*
  std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());

  // XNNPACK Id
  uint32_t id = XNN_INVALID_VALUE_ID;

  // Get Pointer to constant data from flatbuffer, if its non-constant
  // it is a nullptr
  const uint8_t* buffer_ptr = getConstantDataPtr(
      tensor_value,
      flatbuffer_graph,
      constant_data_ptr,
      named_data_map,
      freeable_buffers,
      weights_cache);

  // A non-zero constant_buffer_idx marks a constant tensor. getConstantDataPtr
  // reports load failures by returning nullptr; without this check a failed
  // load would silently be defined below as a non-constant (data-less) tensor.
  ET_CHECK_OR_RETURN_ERROR(
      tensor_value->constant_buffer_idx() == 0 || buffer_ptr != nullptr,
      Internal,
      "Failed to load constant data for tensor %u (constant buffer idx %u)",
      static_cast<uint32_t>(tensor_value->id_out()),
      static_cast<uint32_t>(tensor_value->constant_buffer_idx()));

  xnn_status status;
  // The type we might have to convert to
  auto dq_datatype = getDataType(tensor_value->dq_datatype());

  if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
    if (dq_datatype != xnn_datatype::xnn_datatype_qint8) {
      ET_CHECK_OR_RETURN_ERROR(
          false,
          Internal,
          "Only int8_t is supported for dq_datatype for now, got: %d",
          dq_datatype);
    } else {
      ET_CHECK_OR_RETURN_ERROR(
          (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT),
          Internal,
          "Dynamic quantization of tensor is only allowed for the external input tensor value for now! got flags: %u",
          tensor_value->flags());
    }
  }

  if (qtensor_value == nullptr) {
    // FP32 tensor
    if (!isQuantizedDataType(dq_datatype)) {
      // Define non-quantied tensor
      status = xnn_define_tensor_value(
          /*subgraph=*/subgraph_ptr,
          /*datatype=*/getDataType(tensor_value->datatype()),
          /*num_dims=*/tensor_value->num_dims(),
          /*dims=*/dims_data.data(),
          /*data=*/buffer_ptr,
          /*external_id=*/tensor_value->external_id(),
          /*flags=*/tensor_value->flags(),
          /*id_out=*/&id);
    } else if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
      ET_CHECK_OR_RETURN_ERROR(
          isQuantizedDataType(dq_datatype),
          Internal,
          "Dynamic quantization can only produce supported quantized dtypes");
      ET_CHECK_OR_RETURN_ERROR(
          tensor_value->external_id() != XNN_INVALID_VALUE_ID,
          Internal,
          "Dynamic quantization can only work with external inputs for now, got an internal ID");
      ET_CHECK_OR_RETURN_ERROR(
          buffer_ptr == nullptr,
          Internal,
          "Dynamic quantization can only work with external inputs for now, got const data");

      switch (dq_datatype) {
        case xnn_datatype::xnn_datatype_qint8: {
          // HACK TO Maintain FC/BC for ASR this will be removed after 01/2024

          // When encountering a dynamically quantized tensor via dq_datatype,
          // which is the old flow for serializing dynamically quantized linear.
          // We replace the definition of a single tensor with a new dynamic
          // Quantization pattern. We change the pattern from:
          //     serialized_qd_input
          // to
          //     (fp32_input --> convert --> qdint8_input)

          status = xnn_define_dynamically_quantized_tensor_value(
              /*subgraph=*/subgraph_ptr,
              /*datatype=*/xnn_datatype_qdint8,
              /*num_dims=*/tensor_value->num_dims(),
              /*num_nonbatch_dims=*/1, // always do per token quantization
              /*dims=*/dims_data.data(),
              /*external_id=*/XNN_INVALID_VALUE_ID, // always internal value id
              /*flags=*/0, // this is netiher external input or output
              /*id_out=*/&id);
          // Each step of the pattern must succeed; previously these statuses
          // were overwritten by the next call without being checked.
          ET_CHECK_OR_RETURN_ERROR(
              status == xnn_status_success,
              Internal,
              "Failed to define dynamically quantized tensor %u with code: %s",
              static_cast<uint32_t>(tensor_value->id_out()),
              xnn_status_to_string(status));

          // this is the FP16 or FP32 external value that is being dynamically
          // quantized
          uint32_t float_id;
          enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype());
          status = xnn_define_tensor_value(
              /*subgraph=*/subgraph_ptr,
              /*datatype=*/fp_datatype,
              /*num_dims=*/tensor_value->num_dims(),
              /*dims=*/dims_data.data(),
              /*data=*/buffer_ptr,
              /*external_id=*/tensor_value->external_id(),
              /*flags=*/tensor_value->flags(),
              /*id_out=*/&float_id);
          ET_CHECK_OR_RETURN_ERROR(
              status == xnn_status_success,
              Internal,
              "Failed to define float input of dynamically quantized tensor %u with code: %s",
              static_cast<uint32_t>(tensor_value->id_out()),
              xnn_status_to_string(status));

          // Define dynamic conversion from float to qdint8; its status is
          // checked after the switch together with the other paths.
          status = xnn_define_convert(
              /*subgraph=*/subgraph_ptr,
              /*input_id=*/float_id,
              /*output_id=*/id,
              /*flags=*/0);
          break;
        }
        default:
          ET_CHECK_OR_RETURN_ERROR(
              false,
              NotImplemented,
              "Unhandled Dyanmic Quantization dtype: %d",
              dq_datatype);
      }
    } else {
      ET_CHECK_OR_RETURN_ERROR(false, NotImplemented, "Unhandled fp32 tensor");
    }
  } else {
    // define tensor for quantized
    switch (qtensor_value->quant_params_type()) {
      case fb_xnnpack::XNNQuantParams::PerTensorQuant: {
        auto qparams = qtensor_value->quant_params_as_PerTensorQuant();
        ET_LOG(
            Debug,
            "define quant tensor (per tensor): buffer_ptr: %p, scale: %f, zp: %d\n",
            buffer_ptr,
            qparams->scale(),
            qparams->zero_point());
        status = xnn_define_quantized_tensor_value(
            /*subgraph=*/subgraph_ptr,
            /*datatype=*/getDataType(tensor_value->datatype()),
            /*zero_point=*/qparams->zero_point(),
            /*scale=*/qparams->scale(),
            /*num_dims=*/tensor_value->num_dims(),
            /*dims=*/dims_data.data(),
            /*data=*/buffer_ptr,
            /*external_id=*/tensor_value->external_id(),
            /*flags=*/tensor_value->flags(),
            /*id_out=*/&id);
        break;
      }
      case fb_xnnpack::XNNQuantParams::PerChannelQuant: {
        auto qparams = qtensor_value->quant_params_as_PerChannelQuant();
        enum xnn_datatype dtype = getDataType(tensor_value->datatype());
        int32_t zero_point =
            (dtype == xnn_datatype::xnn_datatype_qcint4 ? 8 : 0);

        ET_LOG(
            Debug,
            "define quant tensor (per channel): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, dtype: %u, zero_point: %d\n",
            buffer_ptr,
            qparams->scale()->size(),
            qparams->channel_dim(),
            dtype,
            zero_point);

        const float* scale = qparams->scale()->data();

        if (qparams->scale_buffer_idx() != 0) {
          scale = reinterpret_cast<const float*>(getConstantDataPtr(
              qparams->scale_buffer_idx(),
              flatbuffer_graph,
              constant_data_ptr,
              named_data_map,
              freeable_buffers,
              weights_cache));
          ET_CHECK_OR_RETURN_ERROR(
              scale != nullptr, Internal, "Failed to load scale data.");
        }
        status = xnn_define_channelwise_quantized_tensor_value_v2(
            /*subgraph=*/subgraph_ptr,
            /*datatype=*/dtype,
            /*zero_point=*/zero_point,
            /*scale=*/scale,
            /*num_dims=*/tensor_value->num_dims(),
            /*channel_dim*/ qparams->channel_dim(),
            /*dims=*/dims_data.data(),
            /*data=*/buffer_ptr,
            /*external_id=*/tensor_value->external_id(),
            /*flags=*/tensor_value->flags(),
            /*id_out=*/&id);
        break;
      }
      case fb_xnnpack::XNNQuantParams::PerChannelGroupQuant: {
        xnn_datatype datatype = getDataType(tensor_value->datatype());
        ET_CHECK_OR_RETURN_ERROR(
            datatype == xnn_datatype::xnn_datatype_qbint4,
            Internal,
            "Unsupported datatype for per channel group quantization: %d",
            datatype);
        auto qparams = qtensor_value->quant_params_as_PerChannelGroupQuant();

        // dims[0] and dims[1] are read below; reject malformed shapes instead
        // of reading past the serialized dims vector.
        ET_CHECK_OR_RETURN_ERROR(
            dims_data.size() >= 2,
            Internal,
            "Per channel group quantization expects at least 2 dims, got %zu",
            dims_data.size());
        size_t group_size = qparams->group_size();
        // group_size is a divisor below; zero would be undefined behavior.
        ET_CHECK_OR_RETURN_ERROR(
            group_size != 0,
            Internal,
            "Per channel group quantization requires a non-zero group size");
        size_t output_channels = tensor_value->dims()->Get(0);
        size_t input_channels = tensor_value->dims()->Get(1);

        const uint16_t* scale_data = nullptr;
        uint32_t scale_numel = 0;

        // Block scales are preferably serialized as bf16 but can also be
        // serialized as fp32 for backwards compatability.
        if (qparams->scale_buffer_idx() != 0) {
          scale_data = reinterpret_cast<const uint16_t*>(getConstantDataPtr(
              qparams->scale_buffer_idx(),
              flatbuffer_graph,
              constant_data_ptr,
              named_data_map,
              freeable_buffers,
              weights_cache));
          ET_CHECK_OR_RETURN_ERROR(
              scale_data != nullptr, Internal, "Failed to load scale data.");
          scale_numel = qparams->num_scales();
        } else {
          ET_CHECK_OR_RETURN_ERROR(
              qparams->scale() != nullptr,
              Internal,
              "Per channel group quantization has neither a scale buffer nor inline scales");
          // Read fp32 scales, convert to bf16.
          auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
              qparams->scale()->size() * sizeof(uint16_t)));
          scale_numel = qparams->scale()->size();
          convertF32TensorToBF16(
              qparams->scale()->data(), conv_buffer, scale_numel);
          scale_data = conv_buffer;
        }

        ET_CHECK_OR_RETURN_ERROR(
            scale_numel == output_channels * input_channels / group_size,
            Internal,
            "scale size %zu != output channels %zu * input channels %zu / group size %zu",
            static_cast<size_t>(scale_numel),
            output_channels,
            input_channels,
            group_size);
        int32_t zero_point =
            (datatype == xnn_datatype::xnn_datatype_qbint4 ? 8 : 0);
        ET_LOG(
            Debug,
            "define quant tensor (per channel group): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, grpup_size: %zu, output_channels: %zu, dtype: %u, zero_point: %d, datatype: %d\n",
            buffer_ptr,
            scale_numel,
            qparams->channel_dim(),
            group_size,
            output_channels,
            datatype,
            zero_point,
            datatype);

        status = xnn_define_blockwise_quantized_tensor_value(
            /*subgraph=*/subgraph_ptr,
            /*datatype=*/datatype,
            /*zero_point=*/zero_point,
            /*scale=*/scale_data,
            /*num_dims=*/tensor_value->num_dims(),
            /*channel_dim=*/qparams->channel_dim(),
            /*block_size=*/qparams->group_size(),
            /*dims=*/dims_data.data(),
            /*data=*/buffer_ptr,
            /*external_id=*/tensor_value->external_id(),
            /*flags=*/tensor_value->flags(),
            /*id_out=*/&id);
        break;
      }
      case fb_xnnpack::XNNQuantParams::PerTokenDynamicQuant: {
        auto qparams = qtensor_value->quant_params_as_PerTokenDynamicQuant();
        ET_LOG(
            Debug,
            "define quant tensor (dynamic): num_dims: %i, num_nonbatch_dims: %i\n",
            tensor_value->num_dims(),
            qparams->num_nonbatch_dims());
        ET_CHECK_OR_RETURN_ERROR(
            buffer_ptr == nullptr,
            Internal,
            "Dynamically quantized tensor should not have constant data but found non-nullptr");
        status = xnn_define_dynamically_quantized_tensor_value(
            /*subgraph=*/subgraph_ptr,
            /*datatype=*/getDataType(tensor_value->datatype()),
            /*num_dims=*/tensor_value->num_dims(),
            /*num_nonbatch_dims*/ qparams->num_nonbatch_dims(),
            /*dims=*/dims_data.data(),
            /*external_id=*/tensor_value->external_id(),
            /*flags=*/tensor_value->flags(),
            /*id_out=*/&id);
        break;
      }
      default: {
        ET_CHECK_OR_RETURN_ERROR(
            false,
            NotImplemented,
            "Unhandled Quantization Parameters: %s",
            fb_xnnpack::EnumNameXNNQuantParams(
                qtensor_value->quant_params_type()));
      }
    }
  }

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to define tensor %i with code: %s",
      tensor_value->id_out(),
      xnn_status_to_string(status));

  // map serialized id to newly generated id
  remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));

  // Add external ids to either list of input or output ids
  if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT) {
    input_ids.push_back(tensor_value->external_id());
  }
  if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
    output_ids.push_back(tensor_value->external_id());
  }

  return Error::Ok;
}
// Explicitly discards a value so an otherwise-unused variable (e.g. the graph
// parameter that every define*Node function takes to match DefineNodeFunc)
// does not trigger unused-parameter warnings.
#define MAYBE_UNUSED(x) (void)(x)
#ifdef ENABLE_XNNPACK_KLEIDI
/*
Decides whether an XNNConvert node should request QP8 packing. The answer is
true only when the convert's output is a qdint8 quantized tensor and at least
one fully-connected node consumes it as input1 with a filter whose datatype is
qbint4, qcint4 or qcint8. One matching linear node is taken as sufficient for
all consumers of the convert output.
*/
bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
  assert(node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNConvert);
  const auto converted_id = node->xnode_union_as_XNNConvert()->output_id();

  // True when the first quantized tensor value with id `value_id` has
  // datatype `expected`; non-quantized values are never considered.
  auto has_quantized_dtype = [graph](uint32_t value_id, DataType expected) {
    for (auto xvalue : *graph->xvalues()) {
      if (xvalue->xvalue_union_type() !=
          fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
        continue;
      }
      auto tensor =
          xvalue->xvalue_union_as_XNNQuantizedTensorValue()->tensor_value();
      if (tensor->id_out() == value_id) {
        return tensor->datatype() == expected;
      }
    }
    return false;
  };

  // The convert must produce a qdint8 tensor; otherwise bail early.
  if (!has_quantized_dtype(converted_id, DataType::xnn_datatype_qdint8)) {
    return false;
  }

  // Filter datatypes for which XNNPACK has QP8 support.
  const std::vector<DataType> qp8_filter_dtypes = {
      DataType::xnn_datatype_qbint4,
      DataType::xnn_datatype_qcint4,
      DataType::xnn_datatype_qcint8};

  for (auto consumer : *graph->xnodes()) {
    if (consumer->xnode_union_type() !=
        fb_xnnpack::XNodeUnion::XNNFullyConnected) {
      continue;
    }
    auto linear = consumer->xnode_union_as_XNNFullyConnected();
    if (linear->input1_id() != converted_id) {
      continue;
    }
    for (auto filter_dtype : qp8_filter_dtypes) {
      if (has_quantized_dtype(linear->filter_id(), filter_dtype)) {
        return true;
      }
    }
  }
  return false;
}
#endif // ENABLE_XNNPACK_KLEIDI
/*
Define Convert operator Node into the subgraph.

The serialized input/output ids are translated through remapped_ids. When
built with ENABLE_XNNPACK_KLEIDI and isQP8() accepts the node, the
XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM bit (0x00000100 unless xnnpack.h provides
it) is OR-ed into the serialized flags.
*/
Error defineConvertNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* flatbuffer_graph) noexcept {
  MAYBE_UNUSED(flatbuffer_graph);
  auto graph_node = node->xnode_union_as_XNNConvert();

  // Flags are a bit mask; keep them unsigned.
  uint32_t flags = graph_node->flags();

#ifdef ENABLE_XNNPACK_KLEIDI
  // This is not currently exposed at include/xnnpack.h yet once it is
  // we can remove this runtime logic and do this ahead-of-time.
  // A scoped constant replaces the former function-local #define, which had a
  // stray trailing ';' and stayed defined for the rest of the translation
  // unit (clashing with any future definition in xnnpack.h).
#ifdef XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM
  constexpr uint32_t kMaybePackForQB4WGemm = XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
#else
  constexpr uint32_t kMaybePackForQB4WGemm = 0x00000100;
#endif
  if (isQP8(flatbuffer_graph, node)) {
    flags |= kMaybePackForQB4WGemm;
    ET_LOG(
        Debug,
        "Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i",
        node->debug_handle());
  }
#endif

  xnn_status status = xnn_define_convert(
      subgraph_ptr,
      remapped_ids.at(graph_node->input_id()),
      remapped_ids.at(graph_node->output_id()),
      flags);

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create convert node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));

  return Error::Ok;
}
/*
Defines a serialized linear (fully-connected) node in the subgraph. Every
serialized value id (input, filter, bias, output) is looked up in remapped_ids
to obtain the id XNNPACK assigned when the tensor values were defined; the
node's optional output clamp comes from getOutputMinMax.
*/
Error defineFullyConnectedNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  const auto* fc = node->xnode_union_as_XNNFullyConnected();
  const std::pair<float, float> clamp = getOutputMinMax(node);

  const uint32_t input_id = remapped_ids.at(fc->input1_id());
  const uint32_t filter_id = remapped_ids.at(fc->filter_id());
  const uint32_t bias_id = remapped_ids.at(fc->bias_id());
  const uint32_t output_id = remapped_ids.at(fc->output_id());

  const xnn_status status = xnn_define_fully_connected(
      subgraph_ptr,
      clamp.first,
      clamp.second,
      input_id,
      filter_id,
      bias_id,
      output_id,
      fc->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create linear node %i, with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));
  return Error::Ok;
}
/*
Defines a serialized softmax node in the subgraph. Its input and output ids
are translated through remapped_ids to the ids XNNPACK assigned when the
tensor values were defined.
*/
Error defineSoftmaxNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  const auto* softmax = node->xnode_union_as_XNNSoftmax();
  const uint32_t input_id = remapped_ids.at(softmax->input_id());
  const uint32_t output_id = remapped_ids.at(softmax->output_id());

  const xnn_status status =
      xnn_define_softmax(subgraph_ptr, input_id, output_id, softmax->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create softmax node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));
  return Error::Ok;
}
/*
Defines a serialized 2D global average pooling node in the subgraph, applying
the node's optional output clamp. Input/output ids are translated through
remapped_ids.
*/
Error defineGlobalAvgPooling2dNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  const auto* pool = node->xnode_union_as_XNNGlobalAvgPooling2d();
  const std::pair<float, float> clamp = getOutputMinMax(node);
  const uint32_t input_id = remapped_ids.at(pool->input_id());
  const uint32_t output_id = remapped_ids.at(pool->output_id());

  const xnn_status status = xnn_define_global_average_pooling_2d(
      subgraph_ptr,
      clamp.first,
      clamp.second,
      input_id,
      output_id,
      pool->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create global average pooling node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));
  return Error::Ok;
}
/*
Defines a serialized 2D average pooling node in the subgraph from its
padding, pooling window, stride and optional output clamp. Input/output ids
are translated through remapped_ids.
*/
Error defineAvgPooling2dNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  const auto* pool = node->xnode_union_as_XNNAvgPooling2d();
  const std::pair<float, float> clamp = getOutputMinMax(node);
  const uint32_t input_id = remapped_ids.at(pool->input_id());
  const uint32_t output_id = remapped_ids.at(pool->output_id());

  const xnn_status status = xnn_define_average_pooling_2d(
      subgraph_ptr,
      pool->padding_top(),
      pool->padding_right(),
      pool->padding_bottom(),
      pool->padding_left(),
      pool->pooling_height(),
      pool->pooling_width(),
      pool->stride_height(),
      pool->stride_width(),
      clamp.first,
      clamp.second,
      input_id,
      output_id,
      pool->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create average pooling node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));
  return Error::Ok;
}
/*
Defines a serialized conv2d node in the subgraph. Geometry (padding, kernel,
subsampling, dilation, grouping) is taken from the node as serialized; the
input, filter, bias and output ids are translated through remapped_ids, and
the optional output clamp comes from getOutputMinMax.
*/
Error defineConv2dNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  const auto* conv = node->xnode_union_as_XNNConv2d();
  const std::pair<float, float> clamp = getOutputMinMax(node);

  const uint32_t input_id = remapped_ids.at(conv->input1_id());
  const uint32_t filter_id = remapped_ids.at(conv->filter_id());
  const uint32_t bias_id = remapped_ids.at(conv->bias_id());
  const uint32_t output_id = remapped_ids.at(conv->output_id());

  const xnn_status status = xnn_define_convolution_2d(
      subgraph_ptr,
      conv->padding_top(),
      conv->padding_right(),
      conv->padding_bottom(),
      conv->padding_left(),
      conv->kernel_height(),
      conv->kernel_width(),
      conv->subsampling_height(),
      conv->subsampling_width(),
      conv->dilation_height(),
      conv->dilation_width(),
      conv->groups(),
      conv->group_input_channels(),
      conv->group_output_channels(),
      clamp.first,
      clamp.second,
      input_id,
      filter_id,
      bias_id,
      output_id,
      conv->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create convolution node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));
  return Error::Ok;
}
/*
Defines a serialized conv_transpose2d node in the subgraph as an XNNPACK
deconvolution. Geometry (padding, output adjustment, kernel, subsampling,
dilation, grouping) is taken from the node as serialized; the input, filter,
bias and output ids are translated through remapped_ids, and the optional
output clamp comes from getOutputMinMax.
*/
Error defineConvTranspose2dNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  const auto* deconv = node->xnode_union_as_XNNConvTranspose2d();
  const std::pair<float, float> clamp = getOutputMinMax(node);

  const uint32_t input_id = remapped_ids.at(deconv->input1_id());
  const uint32_t filter_id = remapped_ids.at(deconv->filter_id());
  const uint32_t bias_id = remapped_ids.at(deconv->bias_id());
  const uint32_t output_id = remapped_ids.at(deconv->output_id());

  const xnn_status status = xnn_define_deconvolution_2d(
      subgraph_ptr,
      deconv->padding_top(),
      deconv->padding_right(),
      deconv->padding_bottom(),
      deconv->padding_left(),
      deconv->adjustment_height(),
      deconv->adjustment_width(),
      deconv->kernel_height(),
      deconv->kernel_width(),
      deconv->subsampling_height(),
      deconv->subsampling_width(),
      deconv->dilation_height(),
      deconv->dilation_width(),
      deconv->groups(),
      deconv->group_input_channels(),
      deconv->group_output_channels(),
      clamp.first,
      clamp.second,
      input_id,
      filter_id,
      bias_id,
      output_id,
      deconv->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create deconvolution node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));
  return Error::Ok;
}
/*
Defines a serialized maxpool2d node in the subgraph from its padding, pooling
window, stride, dilation and optional output clamp. Input/output ids are
translated through remapped_ids.
*/
Error defineMaxPooling2dNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  const auto* pool = node->xnode_union_as_XNNMaxPooling2d();
  const std::pair<float, float> clamp = getOutputMinMax(node);
  const uint32_t input_id = remapped_ids.at(pool->input_id());
  const uint32_t output_id = remapped_ids.at(pool->output_id());

  const xnn_status status = xnn_define_max_pooling_2d(
      subgraph_ptr,
      pool->padding_top(),
      pool->padding_right(),
      pool->padding_bottom(),
      pool->padding_left(),
      pool->pooling_height(),
      pool->pooling_width(),
      pool->stride_height(),
      pool->stride_width(),
      pool->dilation_height(),
      pool->dilation_width(),
      clamp.first,
      clamp.second,
      input_id,
      output_id,
      pool->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create maxpool2d node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));
  return Error::Ok;
}
/*
Define serialized static transpose node into the subgraph, using the remapped
ids to map the serialized ids, to the new ids generated when defining the
tensor value.

The permutation comes straight from the serialized program, and its data
pointer is handed to XNNPACK together with num_dims. Before the call, this
function checks that the permutation is present and has at least num_dims
entries. It returns Internal otherwise.
*/
Error defineStaticTransposeNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);
  auto graph_node = node->xnode_union_as_XNNStaticTranspose();

  ET_CHECK_OR_RETURN_ERROR(
      graph_node->perm() != nullptr,
      Internal,
      "Static transpose node %i has no permutation",
      node->debug_handle());

  // Get tensor dims, we need to convert the uint32_t* to size_t*
  std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());

  ET_CHECK_OR_RETURN_ERROR(
      dims_data.size() >= static_cast<size_t>(graph_node->num_dims()),
      Internal,
      "Static transpose node %i has num_dims %u but only %zu permutation entries",
      node->debug_handle(),
      static_cast<uint32_t>(graph_node->num_dims()),
      dims_data.size());

  xnn_status status = xnn_define_static_transpose(
      subgraph_ptr,
      graph_node->num_dims(),
      dims_data.data(),
      remapped_ids.at(graph_node->input_id()),
      remapped_ids.at(graph_node->output_id()),
      graph_node->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create static transpose node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));

  return Error::Ok;
}
/*
Define serialized static resize bilinear 2d node into the subgraph, using the
remapped ids to map the serialized ids, to the new ids generated when defining
the tensor value
*/
Error defineStaticResizeBilinear2DNode(
xnn_subgraph_t subgraph_ptr,
const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
const NodePtr node,
const fb_xnnpack::XNNGraph* graph) noexcept {
MAYBE_UNUSED(graph);
const fb_xnnpack::XNNStaticResizeBilinear2D* graph_node =