Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/common/ai/segmentation.c
Original file line number Diff line number Diff line change
Expand Up @@ -570,15 +570,16 @@ void dt_seg_warmup_decoder(dt_seg_context_t *ctx)
.shape = has_mask_shape, .ndim = 1};

int64_t masks_shape[4] = {1, nm, dec_h, dec_w};
// shapes must outlive outputs[] used by dt_ai_run below
int64_t iou_shape[2] = {1, nm};
int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim};
float iou_buf[MAX_NUM_MASKS];

dt_ai_tensor_t outputs[3];
int n_out;

if(is_sam)
{
int64_t iou_shape[2] = {1, nm};
int64_t lr_shape[4] = {1, nm, pm_dim, pm_dim};
const int dec_outputs = dt_ai_get_output_count(ctx->decoder);

outputs[0] = (dt_ai_tensor_t){
Expand Down Expand Up @@ -847,14 +848,16 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx,
dt_ai_tensor_t dec_outputs[3];
int n_dec_out;
int64_t masks_shape[4] = {1, nm, dec_h, dec_w};
// shapes must outlive dec_outputs[] used by dt_ai_run below
int64_t iou_shape[2] = {1, nm};
int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim};

float iou_pred[MAX_NUM_MASKS];
float *low_res = NULL;

if(is_sam)
{
// SAM: masks [1,N,H,W] + iou [1,N], optionally low_res [1,N,pm,pm]
int64_t iou_shape[2] = {1, nm};
const int dec_out_count = dt_ai_get_output_count(ctx->decoder);

dec_outputs[0] = (dt_ai_tensor_t){
Expand All @@ -875,7 +878,6 @@ float *dt_seg_compute_mask(dt_seg_context_t *ctx,
g_free(masks);
return NULL;
}
int64_t low_res_shape[4] = {1, nm, pm_dim, pm_dim};
dec_outputs[2] = (dt_ai_tensor_t){
.data = low_res, .type = DT_AI_FLOAT, .shape = low_res_shape, .ndim = 4};
n_dec_out = 3;
Expand Down
16 changes: 10 additions & 6 deletions src/develop/masks/object.c
Original file line number Diff line number Diff line change
Expand Up @@ -425,12 +425,9 @@ static gpointer _encode_thread_func(gpointer data)
dt_seg_disk_cache_save(d->seg, imgid, distort_hash,
rgb, out_w, out_h);

// signal ready immediately so the user can start placing points,
// the warmup below continues on this background thread - if the user
// clicks before it finishes, ORT serializes concurrent Run() calls on
// the same session, so the decode simply waits for the warmup to
// complete first - in practice, users need a moment to position their
// cursor, so the ~1 s warmup usually finishes before the first click
// signal ready so the user can start placing points; warmup continues
// on this thread; _run_decoder joins the thread on the first click to
// avoid a race with warmup on the shared segmentation context
g_atomic_int_set(&d->encode_state, ok ? ENCODE_READY : ENCODE_ERROR);

// warm up decoder with real encoder embeddings so the first user click
Expand Down Expand Up @@ -690,6 +687,13 @@ static void _run_decoder(dt_masks_form_gui_t *gui)
if(gui->guipoints_count <= 0)
return;

// wait for encode thread: warmup may still be running after ENCODE_READY
if(d->encode_thread)
{
g_thread_join(d->encode_thread);
d->encode_thread = NULL;
}

dt_gui_cursor_set_busy();

const float *gp = dt_masks_dynbuf_buffer(gui->guipoints);
Expand Down