@@ -124,31 +124,18 @@ FoundationPose::FoundationPose(std::shared_ptr<inference_core::BaseInferCore>
124124{
125125 // Check
126126 auto refiner_blobs_buffer = refiner_core->GetBuffer (true );
127- if (refiner_blobs_buffer->GetOuterBlobBuffer (RENDER_INPUT_BLOB_NAME ).first == nullptr )
127+ auto scorer_blobs_buffer = scorer_core->GetBuffer (true );
128+ try
128129 {
129- LOG (ERROR ) << " [FoundationPose] Failed to Construct FoundationPose since `renfiner_core` "
130- << " do not has a blob named `" << RENDER_INPUT_BLOB_NAME << " `." ;
131- throw std::runtime_error (" [FoundationPose] Failed to Construct FoundationPose" );
132- }
133- if (refiner_blobs_buffer->GetOuterBlobBuffer (TRANSF_INPUT_BLOB_NAME ).first == nullptr )
134- {
135- LOG (ERROR ) << " [FoundationPose] Failed to Construct FoundationPose since `renfiner_core` "
136- << " do not has a blob named `" << TRANSF_INPUT_BLOB_NAME << " `." ;
137- throw std::runtime_error (" [FoundationPose] Failed to Construct FoundationPose" );
138- }
139-
140- auto scorer_blobs_buffer = scorer_core->GetBuffer (true );
141- if (scorer_blobs_buffer->GetOuterBlobBuffer (RENDER_INPUT_BLOB_NAME ).first == nullptr )
142- {
143- LOG (ERROR ) << " [FoundationPose] Failed to Construct FoundationPose since `scorer_core` "
144- << " do not has a blob named `" << RENDER_INPUT_BLOB_NAME << " `." ;
145- throw std::runtime_error (" [FoundationPose] Failed to Construct FoundationPose" );
146- }
147- if (scorer_blobs_buffer->GetOuterBlobBuffer (TRANSF_INPUT_BLOB_NAME ).first == nullptr )
130+ refiner_blobs_buffer->GetTensor (RENDER_INPUT_BLOB_NAME );
131+ refiner_blobs_buffer->GetTensor (TRANSF_INPUT_BLOB_NAME );
132+ scorer_blobs_buffer->GetTensor (RENDER_INPUT_BLOB_NAME );
133+ scorer_blobs_buffer->GetTensor (TRANSF_INPUT_BLOB_NAME );
134+ } catch (const std::exception &e)
148135 {
149- LOG (ERROR ) << " [FoundationPose] Failed to Construct FoundationPose since `scorer_core` "
150- << " do not has a blob named ` " << TRANSF_INPUT_BLOB_NAME << " `. " ;
151- throw std::runtime_error ( " [FoundationPose] Failed to Construct FoundationPose " );
136+ LOG (ERROR ) << " [FoundationPose] Failed to Construct FoundationPose, ex : " << e. what ();
137+ throw std::runtime_error ( " [FoundationPose] Failed to Construct FoundationPose, ex : " +
138+ std::string (e. what ()) );
152139 }
153140
154141 // preload modules
@@ -210,7 +197,8 @@ bool FoundationPose::Register(const cv::Mat &rgb,
210197 MESSURE_DURATION_AND_CHECK_STATE (UploadDataToDevice (rgb, depth, mask, package),
211198 " [FoundationPose] SyncDetect Failed to upload data!!!" );
212199
213- for (size_t i = 0 ; i < refine_itr ; ++ i) {
200+ for (size_t i = 0 ; i < refine_itr; ++i)
201+ {
214202 MESSURE_DURATION_AND_CHECK_STATE (
215203 RefinePreProcess (package),
216204 " [FoundationPose] SyncDetect Failed to execute RefinePreProcess!!!" );
@@ -258,9 +246,10 @@ bool FoundationPose::Track(const cv::Mat &rgb,
258246 MESSURE_DURATION_AND_CHECK_STATE (UploadDataToDevice (rgb, depth, cv::Mat (), package),
259247 " [FoundationPose] Track Failed to upload data!!!" );
260248
261- for (size_t i = 0 ; i < refine_itr ; ++ i) {
262- MESSURE_DURATION_AND_CHECK_STATE (RefinePreProcess (package),
263- " [FoundationPose] Track Failed to execute RefinePreProcess!!!" );
249+ for (size_t i = 0 ; i < refine_itr; ++i)
250+ {
251+ MESSURE_DURATION_AND_CHECK_STATE (
252+ RefinePreProcess (package), " [FoundationPose] Track Failed to execute RefinePreProcess!!!" );
264253
265254 MESSURE_DURATION_AND_CHECK_STATE (
266255 refiner_core_->SyncInfer (package->GetInferBuffer ()),
@@ -337,44 +326,43 @@ bool FoundationPose::RefinePreProcess(const ParsingType &package)
337326 }
338327
339328 // 2. render
340- if (package->refiner_blobs_buffer == nullptr ) {
329+ if (package->refiner_blobs_buffer == nullptr )
330+ {
341331 package->refiner_blobs_buffer = refiner_core_->GetBuffer (true );
342332 }
343- const auto & refiner_blob_buffer = package->refiner_blobs_buffer ;
333+ const auto &refiner_blobs_buffer = package->refiner_blobs_buffer ;
344334 // 设置推理前blob的输入位置为device,输出的blob位置为host端
345- refiner_blob_buffer-> SetBlobBuffer (RENDER_INPUT_BLOB_NAME , DataLocation::DEVICE );
346- refiner_blob_buffer-> SetBlobBuffer (TRANSF_INPUT_BLOB_NAME , DataLocation::DEVICE );
335+ refiner_blobs_buffer-> GetTensor (RENDER_INPUT_BLOB_NAME )-> SetBufferLocation ( DataLocation::DEVICE );
336+ refiner_blobs_buffer-> GetTensor (TRANSF_INPUT_BLOB_NAME )-> SetBufferLocation ( DataLocation::DEVICE );
347337
348- auto &refine_renderer = map_name2renderer_[package->target_name ];
338+ auto &refine_renderer = map_name2renderer_[package->target_name ];
349339 CHECK_STATE (
350340 refine_renderer->RenderAndTransform (
351341 package->hyp_poses , package->rgb_on_device .get (), package->depth_on_device .get (),
352342 package->xyz_map_on_device .get (), package->input_image_height , package->input_image_width ,
353- refiner_blob_buffer-> GetOuterBlobBuffer (RENDER_INPUT_BLOB_NAME ). first ,
354- refiner_blob_buffer-> GetOuterBlobBuffer (TRANSF_INPUT_BLOB_NAME ). first ,
343+ refiner_blobs_buffer-> GetTensor (RENDER_INPUT_BLOB_NAME )-> RawPtr () ,
344+ refiner_blobs_buffer-> GetTensor (TRANSF_INPUT_BLOB_NAME )-> RawPtr () ,
355345 refine_mode_crop_ratio_),
356346 " [FoundationPose] Failed to render and transform !!!" );
357347 // 3. 设置推理时形状
358- const int input_poses_num = package->hyp_poses .size ();
359- refiner_blob_buffer->SetBlobShape (RENDER_INPUT_BLOB_NAME ,
360- {input_poses_num, crop_window_H_, crop_window_W_, 6 });
361- refiner_blob_buffer->SetBlobShape (TRANSF_INPUT_BLOB_NAME ,
362- {input_poses_num, crop_window_H_, crop_window_W_, 6 });
363- package->infer_buffer = refiner_blob_buffer;
348+ const size_t input_poses_num = package->hyp_poses .size ();
349+ refiner_blobs_buffer->GetTensor (RENDER_INPUT_BLOB_NAME )
350+ ->SetShape ({input_poses_num, static_cast <uint64_t >(crop_window_H_),
351+ static_cast <uint64_t >(crop_window_W_), 6 });
352+ refiner_blobs_buffer->GetTensor (TRANSF_INPUT_BLOB_NAME )
353+ ->SetShape ({input_poses_num, static_cast <uint64_t >(crop_window_H_),
354+ static_cast <uint64_t >(crop_window_W_), 6 });
355+ package->infer_buffer = refiner_blobs_buffer.get ();
364356
365357 return true ;
366358}
367359
368360bool FoundationPose::RefinePostProcess (const ParsingType &package)
369361{
370362 // 获取refiner模型的缓存指针
371- const auto &refiner_blob_buffer = package->refiner_blobs_buffer ;
372- const auto _trans_ptr = refiner_blob_buffer->GetOuterBlobBuffer (REFINE_TRANS_OUT_BLOB_NAME ).first ;
373- const auto _rot_ptr = refiner_blob_buffer->GetOuterBlobBuffer (REFINE_ROT_OUT_BLOB_NAME ).first ;
374- const float *trans_ptr = static_cast <float *>(_trans_ptr);
375- const float *rot_ptr = static_cast <float *>(_rot_ptr);
376- CHECK_STATE (trans_ptr != nullptr , " [FoundationPose] RefinePostProcess got invalid trans_ptr !" );
377- CHECK_STATE (rot_ptr != nullptr , " [FoundationPose] RefinePostProcess got invalid rot_ptr !" );
363+ const auto &refiner_blobs_buffer = package->refiner_blobs_buffer ;
364+ const auto trans_ptr = refiner_blobs_buffer->GetTensor (REFINE_TRANS_OUT_BLOB_NAME )->Cast <float >();
365+ const auto rot_ptr = refiner_blobs_buffer->GetTensor (REFINE_ROT_OUT_BLOB_NAME )->Cast <float >();
378366
379367 // 获取生成的假设位姿
380368 const auto &hyp_poses = package->hyp_poses ;
@@ -419,39 +407,39 @@ bool FoundationPose::RefinePostProcess(const ParsingType &package)
419407
420408bool FoundationPose::ScorePreprocess (const ParsingType &package)
421409{
422- auto scorer_blob_buffer = scorer_core_->GetBuffer (false );
410+ auto scorer_blobs_buffer = scorer_core_->GetBuffer (false );
423411 // 获取对应的score_renderer
424412 // 设置推理前后blob输出的位置,这里输入输出都在device端
425- scorer_blob_buffer->SetBlobBuffer (RENDER_INPUT_BLOB_NAME , DataLocation::DEVICE );
426- scorer_blob_buffer->SetBlobBuffer (TRANSF_INPUT_BLOB_NAME , DataLocation::DEVICE );
427- scorer_blob_buffer->SetBlobBuffer (SCORE_OUTPUT_BLOB_NAME , DataLocation::DEVICE );
413+ scorer_blobs_buffer->GetTensor (RENDER_INPUT_BLOB_NAME )->SetBufferLocation (DataLocation::DEVICE );
414+ scorer_blobs_buffer->GetTensor (TRANSF_INPUT_BLOB_NAME )->SetBufferLocation (DataLocation::DEVICE );
415+ scorer_blobs_buffer->GetTensor (SCORE_OUTPUT_BLOB_NAME )->SetBufferLocation (DataLocation::DEVICE );
416+
428417 auto &score_renderer = map_name2renderer_[package->target_name ];
429418 CHECK_STATE (
430419 score_renderer->RenderAndTransform (
431420 package->hyp_poses , package->rgb_on_device .get (), package->depth_on_device .get (),
432421 package->xyz_map_on_device .get (), package->input_image_height , package->input_image_width ,
433- scorer_blob_buffer->GetOuterBlobBuffer (RENDER_INPUT_BLOB_NAME ).first ,
434- scorer_blob_buffer->GetOuterBlobBuffer (TRANSF_INPUT_BLOB_NAME ).first ,
435- score_mode_crop_ratio_),
422+ scorer_blobs_buffer->GetTensor (RENDER_INPUT_BLOB_NAME )->RawPtr (),
423+ scorer_blobs_buffer->GetTensor (TRANSF_INPUT_BLOB_NAME )->RawPtr (), score_mode_crop_ratio_),
436424 " [FoundationPose] score_renderer RenderAndTransform Failed!!!" );
437425
438- package->scorer_blobs_buffer = scorer_blob_buffer ;
439- package->infer_buffer = scorer_blob_buffer ;
426+ package->scorer_blobs_buffer = scorer_blobs_buffer ;
427+ package->infer_buffer = scorer_blobs_buffer. get () ;
440428
441429 return true ;
442430}
443431
444432bool FoundationPose::ScorePostProcess (const ParsingType &package)
445433{
446- const auto &scorer_blob_buffer = package->scorer_blobs_buffer ;
434+ const auto &scorer_blobs_buffer = package->scorer_blobs_buffer ;
447435 // 获取scorer模型的输出缓存指针
448- void * score_ptr = scorer_blob_buffer-> GetOuterBlobBuffer (SCORE_OUTPUT_BLOB_NAME ). first ;
436+ const auto score_ptr = scorer_blobs_buffer-> GetTensor (SCORE_OUTPUT_BLOB_NAME )-> Cast < float >() ;
449437
450438 const auto &refine_poses = package->hyp_poses ;
451439 const int poses_num = refine_poses.size ();
452440
453441 // 获取置信度最大的refined_pose
454- int max_score_index = getMaxScoreIndex (nullptr , reinterpret_cast < float *>( score_ptr) , poses_num);
442+ int max_score_index = getMaxScoreIndex (nullptr , score_ptr, poses_num);
455443 package->actual_pose = refine_poses[max_score_index];
456444
457445 return true ;
0 commit comments