Skip to content

Commit 4d36770

Browse files
authored
Merge pull request #20815 from andriiryzhkov/seg_fix
[AI] Fix object mask concurrency and shape-lifetime bugs
2 parents 1017fbc + 6382e43 commit 4d36770

2 files changed

Lines changed: 16 additions & 10 deletions

File tree

src/common/ai/segmentation.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,15 +570,16 @@ void dt_seg_warmup_decoder(dt_seg_context_t *ctx)
570570
.shape = has_mask_shape, .ndim = 1};
571571

572572
int64_t masks_shape[4] = {1, nm, dec_h, dec_w};
573+
// shapes must outlive outputs[] used by dt_ai_run below
574+
int64_t iou_shape[2] = {1, nm};
575+
int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim};
573576
float iou_buf[MAX_NUM_MASKS];
574577

575578
dt_ai_tensor_t outputs[3];
576579
int n_out;
577580

578581
if(is_sam)
579582
{
580-
int64_t iou_shape[2] = {1, nm};
581-
int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim};
582583
const int dec_outputs = dt_ai_get_output_count(ctx->decoder);
583584

584585
outputs[0] = (dt_ai_tensor_t){
@@ -847,14 +848,16 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx,
847848
dt_ai_tensor_t dec_outputs[3];
848849
int n_dec_out;
849850
int64_t masks_shape[4] = {1, nm, dec_h, dec_w};
851+
// shapes must outlive dec_outputs[] used by dt_ai_run below
852+
int64_t iou_shape[2] = {1, nm};
853+
int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim};
850854

851855
float iou_pred[MAX_NUM_MASKS];
852856
float *low_res = NULL;
853857

854858
if(is_sam)
855859
{
856860
// SAM: masks [1,N,H,W] + iou [1,N], optionally low_res [1,N,pm,pm]
857-
int64_t iou_shape[2] = {1, nm};
858861
const int dec_out_count = dt_ai_get_output_count(ctx->decoder);
859862

860863
dec_outputs[0] = (dt_ai_tensor_t){
@@ -875,7 +878,6 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx,
875878
g_free(masks);
876879
return NULL;
877880
}
878-
int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim};
879881
dec_outputs[2] = (dt_ai_tensor_t){
880882
.data = low_res, .type = DT_AI_FLOAT, .shape = low_res_shape, .ndim = 4};
881883
n_dec_out = 3;

src/develop/masks/object.c

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -425,12 +425,9 @@ static gpointer _encode_thread_func(gpointer data)
425425
dt_seg_disk_cache_save(d->seg, imgid, distort_hash,
426426
rgb, out_w, out_h);
427427

428-
// signal ready immediately so the user can start placing points,
429-
// the warmup below continues on this background thread - if the user
430-
// clicks before it finishes, ORT serializes concurrent Run() calls on
431-
// the same session, so the decode simply waits for the warmup to
432-
// complete first - in practice, users need a moment to position their
433-
// cursor, so the ~1 s warmup usually finishes before the first click
428+
// signal ready so the user can start placing points; warmup continues
429+
// on this thread; _run_decoder joins the thread on the first click to
430+
// avoid a race with warmup on the shared segmentation context
434431
g_atomic_int_set(&d->encode_state, ok ? ENCODE_READY : ENCODE_ERROR);
435432

436433
// warm up decoder with real encoder embeddings so the first user click
@@ -690,6 +687,13 @@ static void _run_decoder(dt_masks_form_gui_t *gui)
690687
if(gui->guipoints_count <= 0)
691688
return;
692689

690+
// wait for encode thread: warmup may still be running after ENCODE_READY
691+
if(d->encode_thread)
692+
{
693+
g_thread_join(d->encode_thread);
694+
d->encode_thread = NULL;
695+
}
696+
693697
dt_gui_cursor_set_busy();
694698

695699
const float *gp = dt_masks_dynbuf_buffer(gui->guipoints);

0 commit comments

Comments
 (0)