|
9 | 9 | #include <executorch/backends/cadence/hifi/kernels/kernels.h> |
10 | 10 | #include <executorch/backends/cadence/hifi/operators/operators.h> |
11 | 11 | #include <executorch/runtime/kernel/kernel_includes.h> |
| 12 | +#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/op_quantized_conv2d.h> |
12 | 13 |
|
13 | 14 | #define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) |
14 | 15 |
|
@@ -435,9 +436,32 @@ void quantized_conv2d_nhwc_out( |
435 | 436 | const Tensor& bias_scale, |
436 | 437 | double output_scale, |
437 | 438 | int64_t output_zero_point, |
438 | | - __ET_UNUSED const Tensor& out_multiplier, |
439 | | - __ET_UNUSED const Tensor& out_shift, |
| 439 | + const Tensor& out_multiplier, |
| 440 | + const Tensor& out_shift, |
440 | 441 | Tensor& out) { |
| 442 | + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) |
| 443 | + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && |
| 444 | + input.scalar_type() == ::executorch::aten::ScalarType::Short && |
| 445 | + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { |
| 446 | + ::impl::generic::native::quantized_conv2d_nhwc_out( |
| 447 | + ctx, |
| 448 | + input, |
| 449 | + weight, |
| 450 | + bias, |
| 451 | + stride, |
| 452 | + padding, |
| 453 | + dilation, |
| 454 | + groups, |
| 455 | + in_zero_point, |
| 456 | + weight_zero_point, |
| 457 | + bias_scale, |
| 458 | + output_scale, |
| 459 | + output_zero_point, |
| 460 | + out_multiplier, |
| 461 | + out_shift, |
| 462 | + out); |
| 463 | + return; |
| 464 | + } |
441 | 465 | const float bias_scale_float = bias_scale.const_data_ptr<float>()[0]; |
442 | 466 | const int32_t weight_zero_point_int = |
443 | 467 | weight_zero_point.const_data_ptr<int32_t>()[0]; |
@@ -502,8 +526,31 @@ void quantized_conv2d_nhwc_per_tensor_out( |
502 | 526 | __ET_UNUSED int64_t out_multiplier, |
503 | 527 | __ET_UNUSED int64_t out_shift, |
504 | 528 | Tensor& out) { |
505 | | - bool optimized = 0; |
| 529 | + // Handle W8A16 heterogeneous type (int16_t activations, int8_t weights) |
| 530 | + if (out.scalar_type() == ::executorch::aten::ScalarType::Short && |
| 531 | + input.scalar_type() == ::executorch::aten::ScalarType::Short && |
| 532 | + weight.scalar_type() == ::executorch::aten::ScalarType::Char) { |
| 533 | + ::impl::generic::native::quantized_conv2d_nhwc_per_tensor_out( |
| 534 | + ctx, |
| 535 | + input, |
| 536 | + weight, |
| 537 | + bias, |
| 538 | + stride, |
| 539 | + padding, |
| 540 | + dilation, |
| 541 | + groups, |
| 542 | + in_zero_point, |
| 543 | + weight_zero_point, |
| 544 | + bias_scale, |
| 545 | + output_scale, |
| 546 | + output_zero_point, |
| 547 | + out_multiplier, |
| 548 | + out_shift, |
| 549 | + out); |
| 550 | + return; |
| 551 | + } |
506 | 552 |
|
| 553 | + bool optimized = 0; |
507 | 554 | if ((input.scalar_type() == ScalarType::Char) || |
508 | 555 | (input.scalar_type() == ScalarType::Byte)) |
509 | 556 | optimized = 1; |
|
0 commit comments