Skip to content

Commit 1424053

Browse files
committed
Merge remote-tracking branch 'origin/main' into hari/webgpu_perf_1_full
2 parents edc5074 + 3a071a6 commit 1424053

164 files changed

Lines changed: 9929 additions & 1310 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cmake/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,7 @@ else()
638638
check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE)
639639
check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS)
640640
check_cxx_compiler_flag(-Wcharacter-conversion HAS_CHARACTER_CONVERSION)
641+
check_cxx_compiler_flag(-Wno-error=character-conversion HAS_NO_ERROR_CHARACTER_CONVERSION)
641642
check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE)
642643
check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION)
643644
check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS)

cmake/external/onnxruntime_external_deps.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,10 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE AND NOT onnxruntime_USE_VCPKG)
149149
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
150150
onnxruntime_fetchcontent_declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64} EXCLUDE_FROM_ALL)
151151
FetchContent_Populate(protoc_binary)
152-
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$")
152+
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$")
153153
onnxruntime_fetchcontent_declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86} EXCLUDE_FROM_ALL)
154154
FetchContent_Populate(protoc_binary)
155-
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*")
155+
elseif(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^aarch64.*")
156156
onnxruntime_fetchcontent_declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64} EXCLUDE_FROM_ALL)
157157
FetchContent_Populate(protoc_binary)
158158
endif()

cmake/onnxruntime_mlas.cmake

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,9 +636,13 @@ else()
636636
enable_language(ASM)
637637
check_cxx_source_compiles("
638638
#ifdef _AIX
639+
#include <sys/systemcfg.h>
640+
#if !defined(POWER_10)
639641
#define POWER_10 0x40000
642+
#endif
643+
#if !defined(POWER_10_ANDUP)
640644
#define POWER_10_ANDUP (POWER_10)
641-
#include <sys/systemcfg.h>
645+
#endif
642646
#define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP)
643647
int main() {
644648
bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31);
@@ -911,6 +915,11 @@ endif()
911915
if(RISCV64 AND MLAS_SOURCE_IS_NOT_SET)
912916
file(GLOB_RECURSE mlas_platform_srcs CONFIGURE_DEPENDS
913917
"${MLAS_SRC_DIR}/scalar/*.cpp")
918+
# Remove scalar depthwise kernel; replaced by the vectorized version
919+
list(REMOVE_ITEM mlas_platform_srcs
920+
"${MLAS_SRC_DIR}/scalar/SconvDepthwiseKernelScalar.cpp")
921+
list(APPEND mlas_platform_srcs
922+
${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp)
914923

915924
if(onnxruntime_USE_RVV)
916925
set(OLD_CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS}")
@@ -932,11 +941,17 @@ endif()
932941
${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp
933942
${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp
934943
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
944+
${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp
945+
${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp
935946
)
947+
list(REMOVE_ITEM mlas_platform_srcs
948+
"${MLAS_SRC_DIR}/sconv_nchw_depthwise_multiplier_1.cpp")
936949
set_source_files_properties(
937950
${MLAS_SRC_DIR}/riscv64/sgemm_pack_b_rvv.cpp
938951
${MLAS_SRC_DIR}/riscv64/sgemm_kernel_rvv.cpp
939952
${MLAS_SRC_DIR}/riscv64/softmax_kernel_rvv.cpp
953+
${MLAS_SRC_DIR}/riscv64/sconv_depthwise_kernel_rvv.cpp
954+
${MLAS_SRC_DIR}/riscv64/sconv_nchwc_kernel_rvv.cpp
940955
PROPERTIES COMPILE_FLAGS "-march=rv64gcv -mabi=lp64d")
941956
list(APPEND mlas_private_compile_definitions MLAS_USE_RVV=1)
942957
else()

cmake/onnxruntime_unittests.cmake

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ function(filter_test_srcs test_srcs_var)
5050
endfunction()
5151

5252
set(disabled_warnings)
53+
54+
function(onnxruntime_disable_gtest_character_conversion_as_error target_name)
55+
if (HAS_NO_ERROR_CHARACTER_CONVERSION)
56+
target_compile_options(${target_name} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=character-conversion>")
57+
endif()
58+
endfunction()
59+
5360
function(AddTest)
5461
cmake_parse_arguments(_UT "DYN" "TARGET" "LIBS;SOURCES;DEPENDS;TEST_ARGS" ${ARGN})
5562
list(REMOVE_DUPLICATES _UT_SOURCES)
@@ -170,9 +177,7 @@ function(AddTest)
170177
if (${HAS_NOERROR})
171178
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=uninitialized>")
172179
endif()
173-
if (${HAS_CHARACTER_CONVERSION})
174-
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=character-conversion>")
175-
endif()
180+
onnxruntime_disable_gtest_character_conversion_as_error(${_UT_TARGET})
176181
endif()
177182

178183
set(TEST_ARGS ${_UT_TEST_ARGS})
@@ -847,9 +852,7 @@ if(MSVC)
847852
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
848853
else()
849854
target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
850-
if (HAS_CHARACTER_CONVERSION)
851-
target_compile_options(onnxruntime_test_utils PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:-Wno-error=character-conversion>")
852-
endif()
855+
onnxruntime_disable_gtest_character_conversion_as_error(onnxruntime_test_utils)
853856
endif()
854857
if (onnxruntime_USE_NCCL)
855858
target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS})

docs/ContribOperators.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3761,7 +3761,6 @@ This version of the operator has been available since version 1 of the 'com.micr
37613761
### <a name="com.microsoft.NhwcFusedConv"></a><a name="com.microsoft.nhwcfusedconv">**com.microsoft.NhwcFusedConv**</a>
37623762

37633763
NhwcFusedConv is a Conv operator with optional activation and add operators fused in.
3764-
Only has fp16 implementation as of 2023/04/15.
37653764

37663765
#### Version
37673766

@@ -3792,26 +3791,26 @@ This version of the operator has been available since version 1 of the 'com.micr
37923791

37933792
<dl>
37943793
<dt><tt>X</tt> : T</dt>
3795-
<dd></dd>
3794+
<dd>Input activation tensor in channels-last layout. For 2D convolution this is [N, H, W, C], where N is batch size, H/W are spatial dimensions, and C is the number of input channels.</dd>
37963795
<dt><tt>W</tt> : T</dt>
3797-
<dd></dd>
3796+
<dd>Convolution weight tensor in the standard ONNX Conv filter layout [M, C/group, kH, kW], where M is the number of output channels.</dd>
37983797
<dt><tt>B</tt> (optional) : T</dt>
3799-
<dd></dd>
3798+
<dd>Optional 1D bias tensor of shape [M].</dd>
38003799
<dt><tt>Z</tt> (optional) : T</dt>
3801-
<dd>Tensor to be added to the output, must be the same shape and format as the output tensor.</dd>
3800+
<dd>Optional residual/add tensor in the same channels-last layout and shape as the output tensor Y. For 2D convolution this is [N, out_H, out_W, M].</dd>
38023801
</dl>
38033802

38043803
#### Outputs
38053804

38063805
<dl>
38073806
<dt><tt>Y</tt> : T</dt>
3808-
<dd></dd>
3807+
<dd>Output tensor in channels-last layout. For 2D convolution this is [N, out_H, out_W, M], where M is the number of output channels.</dd>
38093808
</dl>
38103809

38113810
#### Type Constraints
38123811

38133812
<dl>
3814-
<dt><tt>T</tt> : tensor(float16)</dt>
3813+
<dt><tt>T</tt> : tensor(float16), tensor(float)</dt>
38153814
<dd>Constrain input and output types to float tensors</dd>
38163815
</dl>
38173816

docs/OperatorKernels.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -810,7 +810,8 @@ The **OpSet Version** column uses the following notation:
810810
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
811811
|LRN|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
812812
|||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
813-
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
813+
|LSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
814+
|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
814815
|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
815816
|LayerNormalization|*in* X:**T**<br> *in* Scale:**T**<br> *in* B:**T**<br> *out* Y:**T**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**<br><br>or<br><br>*in* X:**T**<br> *in* Scale:**V**<br> *in* B:**V**<br> *out* Y:**V**<br> *out* Mean:**U**<br> *out* InvStdDev:**U**|17+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(float)|
816817
|||[1, 16]|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **U** = tensor(double), tensor(float)<br/> **V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@@ -889,10 +890,14 @@ The **OpSet Version** column uses the following notation:
889890
|RNN|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
890891
|||[14, 21]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
891892
|||[7, 13]|**T** = tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(int32)|
892-
|RandomNormal|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
893-
|RandomNormalLike|*in* input:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
894-
|RandomUniform|*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
895-
|RandomUniformLike|*in* input:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
893+
|RandomNormal|*out* output:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
894+
|||[1, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
895+
|RandomNormalLike|*in* input:**T1**<br> *out* output:**T2**|22+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
896+
|||[1, 21]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
897+
|RandomUniform|*out* output:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
898+
|||[1, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
899+
|RandomUniformLike|*in* input:**T1**<br> *out* output:**T2**|22+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
900+
|||[1, 21]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T2** = tensor(double), tensor(float), tensor(float16)|
896901
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* output:**T**|11+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
897902
|Reciprocal|*in* X:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
898903
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,18 @@ static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimiz
124124
// Default is an empty string which means no optimizers are disabled.
125125
static const char* const kOrtSessionOptionsDisableSpecifiedOptimizers = "optimization.disable_specified_optimizers";
126126

127+
// Maximum total output size in bytes that the constant folding optimizer is allowed to produce per node.
128+
// Prevents malicious models from causing excessive memory allocation during optimization.
129+
// If the estimated or actual output size of a constant-foldable node exceeds this limit, the node will
130+
// not be constant folded and will instead be executed at runtime.
131+
//
132+
// Option values:
133+
// - A positive integer (as string): Maximum allowed output size in bytes per constant-folded node.
134+
// Default is "1073741824" (1 GB).
135+
// - "0": Disable the size limit (not recommended for untrusted models).
136+
static const char* const kOrtSessionOptionsConstantFoldingMaxOutputSizeInBytes =
137+
"optimization.constant_folding_max_output_size_in_bytes";
138+
127139
// It controls whether to run graph optimizations in loop or not.
128140
//
129141
// "0": disable. Graph Optimization Loop is disabled.

onnxruntime/contrib_ops/cpu/bert/attention_base.h

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <array>
77
#include <vector>
88
#include "core/common/common.h"
9+
#include "core/common/narrow.h"
910
#include "core/providers/cpu/mlas_backend_kernel_selector_config_utils.h"
1011
#ifndef SHARED_PROVIDER
1112
#include "core/framework/op_kernel.h"
@@ -57,11 +58,11 @@ class AttentionBase {
5758
AttentionBase(const KernelInfoType& info, bool require_same_hidden_size) {
5859
int64_t num_heads = 0;
5960
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
60-
num_heads_ = static_cast<int>(num_heads);
61+
num_heads_ = narrow<int>(num_heads);
6162

6263
is_unidirectional_ = info.template GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
6364
do_rotary_ = info.template GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
64-
rotary_embedding_ = static_cast<int>(info.template GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
65+
rotary_embedding_ = narrow<int>(info.template GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
6566
mask_filter_value_ = info.template GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
6667
scale_ = info.template GetAttrOrDefault<float>("scale", 0.0f);
6768
if (!info.template GetAttrs<int64_t>("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) {
@@ -222,6 +223,14 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
222223
"Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'");
223224
}
224225

226+
// Q, K, V are packed along bias_dims[0]. When their hidden sizes are required to be equal,
227+
// bias_dims[0] == 3 * hidden_size must be a multiple of 3.
228+
if (require_same_hidden_size_ && bias_dims[0] % 3 != 0) {
229+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
230+
"Input 'bias' dimension 0 (", bias_dims[0],
231+
") must be a multiple of 3 (Q, K, V are packed and have equal hidden sizes).");
232+
}
233+
225234
int64_t q_hidden_size = bias_dims[0] / static_cast<int64_t>(3);
226235
int64_t k_hidden_size = q_hidden_size;
227236
int64_t v_hidden_size = k_hidden_size;
@@ -241,6 +250,10 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
241250
q_hidden_size = qkv_hidden_sizes_[0];
242251
k_hidden_size = qkv_hidden_sizes_[1];
243252
v_hidden_size = qkv_hidden_sizes_[2];
253+
} else if (q_hidden_size % num_heads_ != 0) {
254+
// Match the error message produced by the qkv_hidden_sizes path above.
255+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
256+
"hidden_size should be divisible by num_heads:", q_hidden_size);
244257
}
245258

246259
int64_t kv_sequence_length = sequence_length;
@@ -282,14 +295,14 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
282295
"Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
283296
}
284297

285-
if (static_cast<int>(past_dims[2]) != num_heads_) {
298+
if (past_dims[2] != num_heads_) {
286299
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
287300
"Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
288301
}
289302

290-
if (static_cast<int>(past_dims[4]) != k_hidden_size / num_heads_) {
303+
if (past_dims[4] != k_hidden_size / num_heads_) {
291304
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
292-
"Inputs 'past' dimension 2 shall have length of ", k_hidden_size / num_heads_);
305+
"Inputs 'past' dimension 4 shall have length of ", k_hidden_size / num_heads_);
293306
}
294307

295308
if (!past_present_share_buffer_) {
@@ -348,17 +361,17 @@ inline Status AttentionBase::CheckInputs(const TensorShape& input_shape,
348361

349362
if (parameters != nullptr) {
350363
AttentionParameters* output_parameters = reinterpret_cast<AttentionParameters*>(parameters);
351-
output_parameters->batch_size = static_cast<int>(batch_size);
352-
output_parameters->sequence_length = static_cast<int>(sequence_length);
353-
output_parameters->past_sequence_length = static_cast<int>(past_sequence_length);
354-
output_parameters->kv_sequence_length = static_cast<int>(kv_sequence_length);
355-
output_parameters->total_sequence_length = static_cast<int>(total_sequence_length);
356-
output_parameters->max_sequence_length = static_cast<int>(max_sequence_length);
357-
output_parameters->input_hidden_size = static_cast<int>(input_hidden_size);
358-
output_parameters->hidden_size = static_cast<int>(q_hidden_size);
359-
output_parameters->v_hidden_size = static_cast<int>(v_hidden_size);
360-
output_parameters->head_size = static_cast<int>(q_hidden_size) / num_heads_;
361-
output_parameters->v_head_size = static_cast<int>(v_hidden_size) / num_heads_;
364+
output_parameters->batch_size = narrow<int>(batch_size);
365+
output_parameters->sequence_length = narrow<int>(sequence_length);
366+
output_parameters->past_sequence_length = narrow<int>(past_sequence_length);
367+
output_parameters->kv_sequence_length = narrow<int>(kv_sequence_length);
368+
output_parameters->total_sequence_length = narrow<int>(total_sequence_length);
369+
output_parameters->max_sequence_length = narrow<int>(max_sequence_length);
370+
output_parameters->input_hidden_size = narrow<int>(input_hidden_size);
371+
output_parameters->hidden_size = narrow<int>(q_hidden_size);
372+
output_parameters->v_hidden_size = narrow<int>(v_hidden_size);
373+
output_parameters->head_size = narrow<int>(q_hidden_size) / num_heads_;
374+
output_parameters->v_head_size = narrow<int>(v_hidden_size) / num_heads_;
362375
output_parameters->num_heads = num_heads_;
363376
output_parameters->is_unidirectional = is_unidirectional_;
364377
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
@@ -398,7 +411,7 @@ inline Tensor* AttentionBase::GetPresent(TOpKernelContext* context,
398411
int head_size,
399412
int kv_sequence_length,
400413
int& past_sequence_length) const {
401-
past_sequence_length = (nullptr != past) ? static_cast<int>(past->Shape().GetDims()[3]) : 0;
414+
past_sequence_length = (nullptr != past) ? narrow<int>(past->Shape().GetDims()[3]) : 0;
402415
std::array<int64_t, 5> present_dims{2, batch_size, num_heads_,
403416
static_cast<int64_t>(kv_sequence_length) + past_sequence_length, head_size};
404417

0 commit comments

Comments
 (0)