@@ -114,7 +114,7 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
114114 embeddings_ = OrtValue::CreateTensor (model_.p_device_inputs_ ->GetAllocator (), shape_, type_);
115115 std::copy_n (
116116 full_data,
117- window_size_ * hidden_size,
117+ window_size_ * hidden_size * 2 ,
118118 embeddings_->GetTensorMutableData <uint16_t >());
119119
120120 } else if (window_index_ < num_windows_) {
@@ -125,17 +125,17 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
125125 };
126126 embeddings_ = OrtValue::CreateTensor (model_.p_device_inputs_ ->GetAllocator (), shape_, type_);
127127 std::copy_n (
128- full_data + window_index_ * window_size_ * hidden_size,
129- window_size_ * hidden_size,
128+ full_data + window_index_ * window_size_ * hidden_size * 2 ,
129+ window_size_ * hidden_size * 2 ,
130130 embeddings_->GetTensorMutableData <uint16_t >());
131131
132132 } else {
133133 // Final token case (e.g., generated token)
134134 shape_ = {1 , 1 , hidden_size};
135135 embeddings_ = OrtValue::CreateTensor (model_.p_device_inputs_ ->GetAllocator (), shape_, type_);
136136 std::copy_n (
137- full_data + (sequence_length - 1 ) * hidden_size,
138- hidden_size,
137+ full_data + (sequence_length - 1 ) * hidden_size * 2 ,
138+ hidden_size * 2 ,
139139 embeddings_->GetTensorMutableData <uint16_t >());
140140
141141 }
0 commit comments