@@ -121,15 +121,20 @@ Resampler::Resampler(size_t num_queries,
121121 size_t embed_dim,
122122 size_t num_heads,
123123 size_t kv_dim,
124+ size_t image_size,
125+ size_t patch_size,
124126 const infinicore::DataType &dtype,
125127 const infinicore::Device &device)
126128 : num_queries_(num_queries),
127129 embed_dim_(embed_dim),
128130 num_heads_(num_heads),
129131 kv_dim_(kv_dim),
132+ image_size_(image_size),
133+ patch_size_(patch_size),
130134 use_kv_proj_(kv_dim != embed_dim) {
131135 INFINICORE_NN_PARAMETER_INIT (query, ({num_queries_, embed_dim_}, dtype, device));
132136 INFINICORE_NN_PARAMETER_INIT (proj, ({embed_dim_, embed_dim_}, dtype, device));
137+
133138 INFINICORE_NN_MODULE_INIT (attn, embed_dim_, num_heads_, dtype, device);
134139 INFINICORE_NN_MODULE_INIT (ln_q, embed_dim_, 1e-6 , dtype, device);
135140 INFINICORE_NN_MODULE_INIT (ln_kv, embed_dim_, 1e-6 , dtype, device);
@@ -138,6 +143,15 @@ Resampler::Resampler(size_t num_queries,
138143 if (use_kv_proj_) {
139144 INFINICORE_NN_MODULE_INIT (kv_proj, kv_dim_, embed_dim_, false , dtype, device);
140145 }
146+
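147+ // Precompute the full 2D sin-cos position-embedding table at the maximum grid size (image_size / patch_size per side): compute it on the CPU, then copy it into the embedding_table buffer on its device.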
148+ size_t num_patches = image_size_ / patch_size_;
149+ INFINICORE_NN_BUFFER_INIT (embedding_table, ({num_patches, num_patches, embed_dim_}, dtype, device_));
150+ std::vector<float > buf (num_patches * num_patches * embed_dim_);
151+ compute_2d_sincos_pos_embed (buf.data (), embed_dim_, num_patches, num_patches);
152+ auto embedding_table_cpu = infinicore::Tensor::zeros ({num_patches, num_patches, embed_dim_}, dtype, infinicore::Device::cpu ());
153+ write_pos_embed (embedding_table_cpu->data (), embedding_table_cpu->dtype (), buf.data (), num_patches * num_patches * embed_dim_);
154+ embedding_table_->copy_from (embedding_table_cpu);
141155}
142156
143157infinicore::Tensor Resampler::forward (const infinicore::Tensor &x,
@@ -152,32 +166,22 @@ infinicore::Tensor Resampler::forward(const infinicore::Tensor &x,
152166 kv = ln_kv_->forward (kv);
153167
154168 // Read the target sizes on the CPU, then assemble per-sample positional embeddings from the precomputed table
155- std::vector<int64_t > tgt_sizes_host;
156-
157169 auto tgt_cpu = tgt_sizes->to (infinicore::Device::cpu ());
158- auto n = tgt_cpu->numel ();
159- tgt_sizes_host.resize (n);
160- std::memcpy (tgt_sizes_host.data (), tgt_cpu->data (), n * sizeof (int64_t ));
170+ int64_t *tgt_sizes_ptr = (int64_t *)(tgt_cpu->data ());
161171
162- auto pos_cpu = infinicore::Tensor::zeros ({batch_size, seq_len, embed_dim_}, kv->dtype (), infinicore::Device::cpu ());
163- auto *pos_ptr = reinterpret_cast <std::byte *>(pos_cpu->data ());
164- const size_t elem_size = pos_cpu->element_size ();
172+ auto pos_embeddings = infinicore::Tensor::zeros (kv->shape (), kv->dtype (), kv->device ());
165173
166174 for (size_t b = 0 ; b < batch_size; ++b) {
167- size_t tgt_h = 1 ;
168- size_t tgt_w = seq_len;
169- if (!tgt_sizes_host.empty ()) {
170- tgt_h = static_cast <size_t >(tgt_sizes_host[b * 2 ]);
171- tgt_w = static_cast <size_t >(tgt_sizes_host[b * 2 + 1 ]);
172- }
173- const size_t patch_len = tgt_h * tgt_w;
174- std::vector<float > buf (patch_len * embed_dim_);
175- compute_2d_sincos_pos_embed (buf.data (), embed_dim_, tgt_h, tgt_w);
176- write_pos_embed (pos_ptr + b * seq_len * embed_dim_ * elem_size, pos_cpu->dtype (), buf.data (), patch_len * embed_dim_);
175+
176+ auto tgt_h = static_cast <size_t >(tgt_sizes_ptr[b * 2 ]);
177+ auto tgt_w = static_cast <size_t >(tgt_sizes_ptr[b * 2 + 1 ]);
178+
179+ auto src_embeddings = embedding_table_->narrow ({{0 , 0 , tgt_h}, {1 , 0 , tgt_w}});
180+ auto tgt_embeddings = pos_embeddings->narrow ({{0 , b, 1 }, {1 , 0 , tgt_h * tgt_w}})->view ({tgt_h, tgt_w, embed_dim_});
181+ tgt_embeddings->copy_from (src_embeddings);
177182 }
178183
179- auto pos = pos_cpu->to (kv->device ());
180- auto kv_with_pos = infinicore::op::add (kv, pos);
184+ auto kv_with_pos = infinicore::op::add (kv, pos_embeddings);
181185
182186 auto q = ln_q_->forward (query_);
183187 if (q->shape ().size () == 2 ) {
0 commit comments