Skip to content

Commit 5719c97

Browse files
committed
WIP: Fix Windowed embedding to copy float data
1 parent 588b0d1 commit 5719c97

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

src/models/embeddings.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
114114
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
115115
std::copy_n(
116116
full_data,
117-
window_size_ * hidden_size,
117+
window_size_ * hidden_size * 2,
118118
embeddings_->GetTensorMutableData<uint16_t>());
119119

120120
} else if (window_index_ < num_windows_) {
@@ -125,17 +125,17 @@ void WindowedEmbeddings::Update(Embeddings& embeddings) {
125125
};
126126
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
127127
std::copy_n(
128-
full_data + window_index_ * window_size_ * hidden_size,
129-
window_size_ * hidden_size,
128+
full_data + window_index_ * window_size_ * hidden_size * 2,
129+
window_size_ * hidden_size * 2,
130130
embeddings_->GetTensorMutableData<uint16_t>());
131131

132132
} else {
133133
// Final token case (e.g., generated token)
134134
shape_ = {1, 1, hidden_size};
135135
embeddings_ = OrtValue::CreateTensor(model_.p_device_inputs_->GetAllocator(), shape_, type_);
136136
std::copy_n(
137-
full_data + (sequence_length - 1) * hidden_size,
138-
hidden_size,
137+
full_data + (sequence_length - 1) * hidden_size * 2,
138+
hidden_size * 2,
139139
embeddings_->GetTensorMutableData<uint16_t>());
140140

141141
}

0 commit comments

Comments
 (0)