Skip to content

Commit 4fd48ed

Browse files
committed
Tests are green
1 parent 9909a69 commit 4fd48ed

10 files changed

Lines changed: 89 additions & 70 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,7 @@ test/perf/*.jpeg
4141
# ML model
4242
*.pt
4343
*.onnx
44+
45+
# Tools
46+
mise.toml
47+
.tool-versions

config/test.exs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import Config
22

3-
config :nx,
4-
default_backend: EXLA.Backend
5-
63
config :logger,
74
level: :warning
85

6+
config :nx, :default_defn_options, [compiler: EXLA]

lib/classification.ex

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ if ImageVision.bumblebee_configured?() do
157157
@spec classifier(configuration :: Keyword.t()) ::
158158
{Nx.Serving, Keyword.t()} | {:error, Image.error()}
159159
def classifier(classifier \\ Application.get_env(:image_vision, :classifier, [])) do
160-
Application.ensure_all_started(:exla)
161160
classifier = Keyword.merge(@default_classifier, classifier)
162161

163162
model = Keyword.fetch!(classifier, :model)
@@ -224,7 +223,6 @@ if ImageVision.bumblebee_configured?() do
224223
@spec embedder(configuration :: Keyword.t()) ::
225224
{Nx.Serving, Keyword.t()} | {:error, Image.error()}
226225
def embedder(embedder \\ Application.get_env(:image_vision, :embedder, [])) do
227-
Application.ensure_all_started(:exla)
228226
embedder = Keyword.merge(@default_embedder, embedder)
229227

230228
model = Keyword.fetch!(embedder, :model)
@@ -256,7 +254,7 @@ if ImageVision.bumblebee_configured?() do
256254
{:ok, featurizer} = Bumblebee.load_featurizer(featurizer, featurizer_options) do
257255
Bumblebee.Vision.image_classification(model_info, featurizer,
258256
compile: [batch_size: batch_size],
259-
defn_options: [compiler: EXLA]
257+
defn_options: defn_options()
260258
)
261259
end
262260
end
@@ -267,11 +265,23 @@ if ImageVision.bumblebee_configured?() do
267265
{:ok, featurizer} = Bumblebee.load_featurizer(featurizer, featurizer_options) do
268266
Bumblebee.Vision.image_embedding(model_info, featurizer,
269267
compile: [batch_size: batch_size],
270-
defn_options: [compiler: EXLA]
268+
defn_options: defn_options()
271269
)
272270
end
273271
end
274272

273+
# Use EXLA as the Nx compiler when it is properly loaded and implements
274+
# the current Nx.Defn.Compiler protocol. Falls back to the default
275+
# evaluator when EXLA is absent or version-mismatched (e.g. EXLA 0.10
276+
# paired with Nx 0.11 does not export __compile__/4).
277+
defp defn_options do
278+
if Code.ensure_loaded?(EXLA) and function_exported?(EXLA, :__compile__, 4) do
279+
[compiler: EXLA]
280+
else
281+
[]
282+
end
283+
end
284+
275285
@doc """
276286
Classifies an image and returns the full prediction map.
277287

lib/detection.ex

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,11 @@ if ImageVision.ortex_configured?() do
266266

267267
padded = Image.embed!(resized, @input_size, @input_size, x: 0, y: 0)
268268

269-
mean = Nx.tensor([0.485, 0.456, 0.406])
270-
std = Nx.tensor([0.229, 0.224, 0.225])
271-
272269
tensor =
273270
padded
274-
|> Image.to_nx!()
271+
|> Image.to_nx!(backend: Nx.BinaryBackend)
275272
|> Nx.as_type(:f32)
276273
|> Nx.divide(255.0)
277-
|> NxImage.normalize(mean, std)
278274
|> Nx.transpose(axes: [2, 0, 1])
279275
|> Nx.new_axis(0)
280276

@@ -297,8 +293,12 @@ if ImageVision.ortex_configured?() do
297293
original_height = Keyword.fetch!(opts, :original_height)
298294
min_score = Keyword.fetch!(opts, :min_score)
299295

300-
scores = Nx.sigmoid(logits[0])
301-
boxes = pred_boxes[0]
296+
scores =
297+
logits[0]
298+
|> Nx.backend_transfer(Nx.BinaryBackend)
299+
|> Nx.sigmoid()
300+
301+
boxes = pred_boxes[0] |> Nx.backend_transfer(Nx.BinaryBackend)
302302

303303
best_class = Nx.argmax(scores, axis: 1)
304304
best_score = Nx.reduce_max(scores, axes: [1])

lib/segmentation.ex

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ if ImageVision.ortex_configured?() do
180180

181181
orig_w = Image.width(image)
182182
orig_h = Image.height(image)
183-
input_scale = @sam_input_size / max(orig_w, orig_h)
183+
input_scale = max(orig_w, orig_h) / @sam_input_size
184184

185185
{point_coords, point_labels} =
186186
encode_sam_prompt(prompt, orig_w, orig_h, input_scale)
@@ -265,13 +265,13 @@ if ImageVision.ortex_configured?() do
265265
orig_w = Image.width(image)
266266
orig_h = Image.height(image)
267267

268-
{batch, input_h, input_w} = detr_preprocess(image)
268+
{batch, _input_h, _input_w} = detr_preprocess(image)
269269

270270
{logits, _pred_boxes, pred_masks} = Ortex.run(model, batch)
271+
logits = Nx.backend_transfer(logits, Nx.BinaryBackend)
272+
pred_masks = Nx.backend_transfer(pred_masks, Nx.BinaryBackend)
271273

272274
detr_postprocess(logits, pred_masks, id2label,
273-
input_h: input_h,
274-
input_w: input_w,
275275
orig_w: orig_w,
276276
orig_h: orig_h,
277277
min_score: min_score
@@ -413,7 +413,7 @@ if ImageVision.ortex_configured?() do
413413

414414
tensor =
415415
padded
416-
|> Image.to_nx!()
416+
|> Image.to_nx!(backend: Nx.BinaryBackend)
417417
|> Nx.as_type(:f32)
418418
|> Nx.divide(255.0)
419419
|> NxImage.normalize(Nx.tensor(@imagenet_mean), Nx.tensor(@imagenet_std))
@@ -424,16 +424,17 @@ if ImageVision.ortex_configured?() do
424424
end
425425

426426
defp sam_encode(encoder, tensor) do
427-
{image_embed, high_res_feats_0, high_res_feats_1} = Ortex.run(encoder, tensor)
427+
# Model outputs in order: high_res_feats_0, high_res_feats_1, image_embed
428+
{high_res_feats_0, high_res_feats_1, image_embed} = Ortex.run(encoder, tensor)
428429
{image_embed, high_res_feats_0, high_res_feats_1}
429430
end
430431

431432
# Converts a user prompt into SAM point_coords + point_labels tensors.
432433
# Coords are in the 1024×1024 input space.
433-
defp encode_sam_prompt(:auto, orig_w, orig_h, _input_scale) do
434+
defp encode_sam_prompt(:auto, orig_w, orig_h, input_scale) do
434435
cx = orig_w / 2
435436
cy = orig_h / 2
436-
encode_sam_prompt({:point, cx, cy}, orig_w, orig_h, orig_w / @sam_input_size)
437+
encode_sam_prompt({:point, cx, cy}, orig_w, orig_h, input_scale)
437438
end
438439

439440
defp encode_sam_prompt({:point, x, y}, _orig_w, _orig_h, scale) do
@@ -469,7 +470,7 @@ if ImageVision.ortex_configured?() do
469470
mask_input = Nx.broadcast(Nx.tensor(0, type: :f32), {1, 1, 256, 256})
470471
has_mask = Nx.tensor([0], type: :f32)
471472

472-
{masks, iou_predictions, _low_res} =
473+
{masks, iou_predictions} =
473474
Ortex.run(decoder, {
474475
image_embed,
475476
high_res_feats_0,
@@ -480,29 +481,33 @@ if ImageVision.ortex_configured?() do
480481
has_mask
481482
})
482483

483-
{masks, iou_predictions}
484+
{
485+
Nx.backend_transfer(masks, Nx.BinaryBackend),
486+
Nx.backend_transfer(iou_predictions, Nx.BinaryBackend)
487+
}
484488
end
485489

486-
# Converts a SAM output mask tensor [256, 256] (logits) to a
490+
# Converts a SAM output mask tensor {256, 256} (logits) to a
487491
# single-band Vimage at original image dimensions.
488492
defp sam_mask_to_image(mask_tensor, orig_w, orig_h, resized_w, resized_h) do
489-
# Up to the padded 1024×1024 input, then crop to the valid region,
490-
# then resize to original dimensions.
491-
upscaled =
492-
mask_tensor
493-
|> Nx.reshape({1024, 1024, 1})
494-
|> Nx.slice([0, 0, 0], [resized_h, resized_w, 1])
493+
# The 256×256 mask covers the full 1024×1024 padded input space.
494+
# Crop to the valid region, which is resized_w×resized_h in input
495+
# space, scaled to mask space by the factor 256/1024 = 1/4.
496+
valid_h = round(resized_h * 256 / @sam_input_size)
497+
valid_w = round(resized_w * 256 / @sam_input_size)
495498

496-
# Threshold at 0 (logits > 0 → object)
497499
binary =
498-
upscaled
500+
mask_tensor
501+
|> Nx.slice([0, 0], [valid_h, valid_w])
499502
|> Nx.greater(0)
500503
|> Nx.multiply(255)
501504
|> Nx.as_type(:u8)
505+
|> Nx.transpose()
506+
|> Nx.new_axis(2)
502507

503508
binary
504509
|> Image.from_nx!()
505-
|> Image.resize!(orig_w / resized_w, vertical_scale: orig_h / resized_h)
510+
|> Image.resize!(orig_w / valid_w, vertical_scale: orig_h / valid_h)
506511
end
507512

508513
# --- Private: DETR-panoptic pre/post --------------------------------
@@ -525,14 +530,15 @@ if ImageVision.ortex_configured?() do
525530

526531
tensor =
527532
resized
528-
|> Image.to_nx!()
533+
|> Image.to_nx!(backend: Nx.BinaryBackend)
529534
|> Nx.as_type(:f32)
530535
|> Nx.divide(255.0)
531536
|> NxImage.normalize(Nx.tensor(@imagenet_mean), Nx.tensor(@imagenet_std))
532537
|> Nx.transpose(axes: [2, 0, 1])
533538
|> Nx.new_axis(0)
534539

535-
{tensor, input_h, input_w}
540+
pixel_mask = Nx.broadcast(Nx.tensor(1, type: :s64), {1, 64, 64})
541+
{{tensor, pixel_mask}, input_h, input_w}
536542
end
537543

538544
# Loads id2label from config.json; cached in :persistent_term.
@@ -556,8 +562,6 @@ if ImageVision.ortex_configured?() do
556562
# Converts raw DETR-panoptic outputs into a list of segments.
557563
defp detr_postprocess(logits, pred_masks, id2label, opts) do
558564
min_score = Keyword.fetch!(opts, :min_score)
559-
input_h = Keyword.fetch!(opts, :input_h)
560-
input_w = Keyword.fetch!(opts, :input_w)
561565
orig_w = Keyword.fetch!(opts, :orig_w)
562566
orig_h = Keyword.fetch!(opts, :orig_h)
563567

@@ -578,9 +582,8 @@ if ImageVision.ortex_configured?() do
578582
class_list = Nx.to_flat_list(best_class)
579583
score_list = Nx.to_flat_list(best_score)
580584

581-
# pred_masks: [1, 100, H/4, W/4]
582-
mask_h = div(input_h, 4)
583-
mask_w = div(input_w, 4)
585+
# pred_masks: [1, queries, H/4, W/4]
586+
{1, _num_queries, mask_h, mask_w} = Nx.shape(pred_masks)
584587

585588
Enum.zip(class_list, score_list)
586589
|> Enum.with_index()
@@ -591,11 +594,12 @@ if ImageVision.ortex_configured?() do
591594

592595
mask =
593596
mask_tensor
594-
|> Nx.reshape({mask_h, mask_w, 1})
595597
|> Nx.sigmoid()
596598
|> Nx.greater(0.5)
597599
|> Nx.multiply(255)
598600
|> Nx.as_type(:u8)
601+
|> Nx.transpose()
602+
|> Nx.new_axis(2)
599603
|> Image.from_nx!()
600604
|> Image.resize!(orig_w / mask_w, vertical_scale: orig_h / mask_h)
601605

mix.exs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ defmodule ImageVision.MixProject do
7575
{:nx, "~> 0.11", optional: true, override: true},
7676
{:nx_image, "~> 0.1", optional: true},
7777
{:bumblebee, "~> 0.6", optional: true},
78+
{:exla, "~> 0.11", optional: true},
7879

7980
# --- Tooling ---
8081
{:ex_doc, "~> 0.18", only: [:release, :dev, :docs]},

mix.lock

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
"elixir_make": {:hex, :elixir_make, "0.9.0", "6484b3cd8c0cee58f09f05ecaf1a140a8c97670671a6a0e7ab4dc326c3109726", [:mix], [], "hexpm", "db23d4fd8b757462ad02f8aa73431a426fe6671c80b200d9710caf3d1dd0ffdb"},
1313
"erlex": {:hex, :erlex, "0.2.8", "cd8116f20f3c0afe376d1e8d1f0ae2452337729f68be016ea544a72f767d9c12", [:mix], [], "hexpm", "9d66ff9fedf69e49dc3fd12831e12a8a37b76f8651dd21cd45fcf5561a8a7590"},
1414
"ex_doc": {:hex, :ex_doc, "0.40.1", "67542e4b6dde74811cfd580e2c0149b78010fd13001fda7cfeb2b2c2ffb1344d", [:mix], [{:earmark_parser, "~> 1.4.44", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "bcef0e2d360d93ac19f01a85d58f91752d930c0a30e2681145feea6bd3516e00"},
15-
"exla": {:hex, :exla, "0.10.0", "93e7d75a774fbc06ce05b96de20c4b01bda413b315238cb3c727c09a05d2bc3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1.0", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.10.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.9.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "16fffdb64667d7f0a3bc683fdcd2792b143a9b345e4b1f1d5cd50330c63d8119"},
15+
"exla": {:hex, :exla, "0.11.0", "1428de9edcb297480a64611d3a72fcefe13c93c115bba6d38e910583c37e38c8", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:fine, "~> 0.1", [hex: :fine, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.11.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.10.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "1067207c802bd6f28cded6a2664979ee2e25dddce95cb84be3f0a3ebfbab2c74"},
1616
"finch": {:hex, :finch, "0.21.0", "b1c3b2d48af02d0c66d2a9ebfb5622be5c5ecd62937cf79a88a7f98d48a8290c", [:mix], [{:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.6.2 or ~> 1.7", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.1", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "87dc6e169794cb2570f75841a19da99cfde834249568f2a5b121b809588a4377"},
17-
"fine": {:hex, :fine, "0.1.5", "54880d13ab2c57884a105502d4fdce041f5852fe19bcfdcbcd7327225b1a6d5a", [:mix], [], "hexpm", "39f8f3f48a21e053c483f362b7b6a3bb38fdd987b31debc4d4e7a77fe8564335"},
17+
"fine": {:hex, :fine, "0.1.6", "4bf7151493443c454aac9f2fa2f34f5fefd0346a83fb5586a016c4a135c63247", [:mix], [], "hexpm", "5638eb4495488e885ebec167fa57973e5c35e1a50c344eb7666c90ec1c4e3b12"},
1818
"hpax": {:hex, :hpax, "1.0.3", "ed67ef51ad4df91e75cc6a1494f851850c0bd98ebc0be6e81b026e765ee535aa", [:mix], [], "hexpm", "8eab6e1cfa8d5918c2ce4ba43588e894af35dbd8e91e6e55c817bca5847df34a"},
1919
"image": {:hex, :image, "0.65.0", "44908233a1a0dcdbb6ae873ec09fd9ae533d1840d300d8b0b1b186d586b935e6", [:mix], [{:color, "~> 0.4", [hex: :color, repo: "hexpm", optional: false]}, {:evision, "~> 0.1.33 or ~> 0.2", [hex: :evision, repo: "hexpm", optional: true]}, {:exla, "0.11.0", [hex: :exla, repo: "hexpm", optional: true]}, {:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: true]}, {:kino, "~> 0.13", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.11.0", [hex: :nx, repo: "hexpm", optional: true]}, {:nx_image, "~> 0.1", [hex: :nx_image, repo: "hexpm", optional: true]}, {:phoenix_html, "~> 2.1 or ~> 3.2 or ~> 4.0", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:plug, "~> 1.13", [hex: :plug, repo: "hexpm", optional: true]}, {:req, "~> 0.4", [hex: :req, repo: "hexpm", optional: true]}, {:rustler, "> 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:scholar, "~> 0.3", [hex: :scholar, repo: "hexpm", optional: true]}, {:sweet_xml, "~> 0.7", [hex: :sweet_xml, repo: "hexpm", optional: false]}, {:vix, "~> 0.33", [hex: :vix, repo: "hexpm", optional: false]}, {:xav, "~> 0.10", [hex: :xav, repo: "hexpm", optional: true]}], "hexpm", "d2060e08d0f42564f49de1ea97a82a5d237f9ac91edb141dece51f1238dd8b4a"},
2020
"jason": {:hex, :jason, "1.4.4", "b9226785a9aa77b6857ca22832cffa5d5011a667207eb2a0ad56adb5db443b8a", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "c5eb0cab91f094599f94d55bc63409236a8ec69a21a67814529e8d5f6cc90b3b"},
@@ -47,5 +47,5 @@
4747
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
4848
"unzip": {:hex, :unzip, "0.12.0", "beed92238724732418b41eba77dcb7f51e235b707406c05b1732a3052d1c0f36", [:mix], [], "hexpm", "95655b72db368e5a84951f0bed586ac053b55ee3815fd96062fce10ce4fc998d"},
4949
"vix": {:hex, :vix, "0.38.0", "77529ee4f6ced339c3d5f90a9eacf306f5b7109d3d1b5e3ef391a984ad404f75", [:make, :mix], [{:cc_precompiler, "~> 0.1.4 or ~> 0.2", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.3 or ~> 0.8", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}], "hexpm", "dca58f654922fa678d5df8e028317483d9c0f8acb2e2714076a8468695687aa7"},
50-
"xla": {:hex, :xla, "0.9.1", "cca0040ff94902764007a118871bfc667f1a0085d4a5074533a47d6b58bec61e", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "eb5e443ae5391b1953f253e051f2307bea183b59acee138053a9300779930daf"},
50+
"xla": {:hex, :xla, "0.10.0", "41121e9f011456242d3a79b9289910ce43419be0b0e7ebe67cc1292c6b3f232f", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "f57d91aea6e661b52bf12239316c598679e9170628122bbd941235f040122bc6"},
5151
}

test/classification_test.exs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ defmodule Image.ClassificationTest do
22
use ExUnit.Case, async: false
33

44
@moduletag :ml
5+
@moduletag :classification
56

67
@images Path.join(__DIR__, "support/images")
78

89
# Start the classifier serving once for the whole suite. Loading
910
# ConvNeXt-tiny-224 takes several seconds, so we do it here rather
1011
# than per-test.
1112
setup_all do
13+
Application.ensure_all_started(:exla)
1214
spec = Image.Classification.classifier()
1315
start_supervised!(spec)
1416
:ok
@@ -29,7 +31,7 @@ defmodule Image.ClassificationTest do
2931
describe "labels/2" do
3032
test "classifies a Cavalier King Charles Spaniel as a Blenheim spaniel" do
3133
image = Image.open!(Path.join(@images, "puppy.webp"))
32-
labels = Image.Classification.labels(image)
34+
labels = Image.Classification.labels(image, min_score: 0.1)
3335
assert "Blenheim spaniel" in labels
3436
end
3537

test/detection_test.exs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ defmodule Image.DetectionTest do
22
use ExUnit.Case, async: false
33

44
@moduletag :ml
5+
@moduletag :ortex
56

67
@corpus_dir Path.join(__DIR__, "support/images/segmentation")
78
@corpus Path.join(@corpus_dir, "corpus.json")
@@ -41,7 +42,7 @@ defmodule Image.DetectionTest do
4142
for entry <- @corpus do
4243
image = open_image(entry)
4344
expected = entry["coco_class"]
44-
detections = Image.Detection.detect(image)
45+
detections = Image.Detection.detect(image, min_score: 0.1)
4546
labels = Enum.map(detections, & &1.label)
4647

4748
assert Enum.any?(labels, &(&1 == expected)),
@@ -55,7 +56,7 @@ defmodule Image.DetectionTest do
5556
expected = entry["coco_class"]
5657
[gt_x, gt_y, gt_w, gt_h] = entry["prompt_box"]
5758

58-
detections = Image.Detection.detect(image)
59+
detections = Image.Detection.detect(image, min_score: 0.1)
5960

6061
best_iou =
6162
detections

0 commit comments

Comments
 (0)