Skip to content

Commit b4439bf

Browse files
committed
Add mix image_vision.download task to pre-fetch default models
1 parent 0e7ac30 commit b4439bf

1 file changed

Lines changed: 184 additions & 0 deletions

File tree

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
defmodule Mix.Tasks.ImageVision.Download do
2+
@shortdoc "Pre-downloads ImageVision default models into the local cache"
3+
4+
@moduledoc """
5+
Pre-downloads the default models used by `Image.Classification`,
6+
`Image.Segmentation`, and `Image.Detection` so that first-call
7+
latency is eliminated and the application can run fully offline.
8+
9+
By default, every category's models are fetched. Pass one or more of
10+
`--classify`, `--segment`, `--detect` to limit the scope. Any
11+
category whose optional dependency is not loaded is skipped with a
12+
notice rather than treated as an error.
13+
14+
ONNX weights for segmentation and detection are stored under the
15+
`ImageVision.ModelCache` cache root (see that module for cache
16+
configuration). Bumblebee classifier and embedder weights are
17+
stored under Bumblebee's own HuggingFace cache (controlled by
18+
`BUMBLEBEE_CACHE_DIR` / the standard HF cache env vars).
19+
20+
## Usage
21+
22+
mix image_vision.download
23+
mix image_vision.download --classify
24+
mix image_vision.download --segment --detect
25+
26+
## Configuration
27+
28+
The task respects user overrides for the Bumblebee classifier and
29+
embedder:
30+
31+
config :image_vision, :classifier,
32+
model: {:hf, "facebook/convnext-large-224-22k-1k"},
33+
featurizer: {:hf, "facebook/convnext-large-224-22k-1k"}
34+
35+
Configured values are downloaded; unset values fall back to the
36+
library defaults.
37+
38+
## Options
39+
40+
* `--classify` downloads the classifier and embedder Bumblebee
41+
models. Requires `:bumblebee` and `:nx`.
42+
43+
* `--segment` downloads the SAM 2 and DETR-panoptic ONNX weights.
44+
Requires `:ortex`.
45+
46+
* `--detect` downloads the RT-DETR ONNX weights. Requires
47+
`:ortex`.
48+
49+
"""
50+
51+
use Mix.Task
52+
53+
# Defaults mirror the runtime defaults in
54+
# `Image.Classification`, `Image.Segmentation`, and `Image.Detection`.
55+
# Kept here as the single source of truth for the download task —
56+
# the runtime modules carry their own copy because they are gated
57+
# on optional deps and may not be compiled.
58+
59+
@sam_repo "SharpAI/sam2-hiera-tiny-onnx"
60+
@sam_files ["encoder.onnx", "decoder.onnx"]
61+
62+
@detr_panoptic_repo "Xenova/detr-resnet-50-panoptic"
63+
@detr_panoptic_files ["onnx/model.onnx", "config.json"]
64+
65+
@rtdetr_repo "onnx-community/rtdetr_r50vd"
66+
@rtdetr_files ["onnx/model.onnx"]
67+
68+
@default_classifier_model {:hf, "facebook/convnext-tiny-224"}
69+
@default_classifier_featurizer {:hf, "facebook/convnext-tiny-224"}
70+
@default_embedder_model {:hf, "facebook/dinov2-base"}
71+
@default_embedder_featurizer {:hf, "facebook/dinov2-base"}
72+
73+
@switches [classify: :boolean, segment: :boolean, detect: :boolean]
74+
75+
@impl Mix.Task
76+
def run(argv) do
77+
{options, _args, _invalid} = OptionParser.parse(argv, strict: @switches)
78+
79+
categories =
80+
case Enum.filter(options, fn {_k, v} -> v end) do
81+
[] -> [:classify, :segment, :detect]
82+
selected -> Enum.map(selected, fn {k, _} -> k end)
83+
end
84+
85+
Mix.Task.run("app.config")
86+
Application.ensure_all_started(:req)
87+
88+
Enum.each(categories, &download/1)
89+
90+
Mix.shell().info("")
91+
Mix.shell().info("Done.")
92+
end
93+
94+
defp download(:classify) do
95+
Mix.shell().info("")
96+
Mix.shell().info("[classification]")
97+
98+
if bumblebee_loaded?() do
99+
Application.ensure_all_started(:bumblebee)
100+
101+
classifier = configuration(:classifier)
102+
embedder = configuration(:embedder)
103+
104+
load_bumblebee(:model, Keyword.get(classifier, :model, @default_classifier_model))
105+
106+
load_bumblebee(
107+
:featurizer,
108+
Keyword.get(classifier, :featurizer, @default_classifier_featurizer)
109+
)
110+
111+
load_bumblebee(:model, Keyword.get(embedder, :model, @default_embedder_model))
112+
113+
load_bumblebee(
114+
:featurizer,
115+
Keyword.get(embedder, :featurizer, @default_embedder_featurizer)
116+
)
117+
else
118+
Mix.shell().info(" skipped — :bumblebee dependency not loaded")
119+
end
120+
end
121+
122+
defp download(:segment) do
123+
Mix.shell().info("")
124+
Mix.shell().info("[segmentation]")
125+
126+
if ortex_loaded?() do
127+
Enum.each(@sam_files, &fetch_onnx(@sam_repo, &1))
128+
Enum.each(@detr_panoptic_files, &fetch_onnx(@detr_panoptic_repo, &1))
129+
else
130+
Mix.shell().info(" skipped — :ortex dependency not loaded")
131+
end
132+
end
133+
134+
defp download(:detect) do
135+
Mix.shell().info("")
136+
Mix.shell().info("[detection]")
137+
138+
if ortex_loaded?() do
139+
Enum.each(@rtdetr_files, &fetch_onnx(@rtdetr_repo, &1))
140+
else
141+
Mix.shell().info(" skipped — :ortex dependency not loaded")
142+
end
143+
end
144+
145+
defp fetch_onnx(repo, filename) do
146+
if ImageVision.ModelCache.cached?(repo, filename) do
147+
Mix.shell().info(" cached #{repo}/#{filename}")
148+
else
149+
Mix.shell().info(" download #{repo}/#{filename}")
150+
_path = ImageVision.ModelCache.fetch!(repo, filename)
151+
:ok
152+
end
153+
end
154+
155+
defp load_bumblebee(kind, {:hf, name} = spec) do
156+
Mix.shell().info(" load #{name} (#{kind})")
157+
158+
result =
159+
case kind do
160+
:model -> Bumblebee.load_model(spec)
161+
:featurizer -> Bumblebee.load_featurizer(spec)
162+
end
163+
164+
case result do
165+
{:ok, _loaded} ->
166+
:ok
167+
168+
{:error, reason} ->
169+
Mix.raise("failed to load #{kind} #{inspect(spec)}: #{inspect(reason)}")
170+
end
171+
end
172+
173+
defp configuration(key) do
174+
Application.get_env(:image_vision, key, [])
175+
end
176+
177+
defp bumblebee_loaded? do
178+
Code.ensure_loaded?(Bumblebee)
179+
end
180+
181+
defp ortex_loaded? do
182+
Code.ensure_loaded?(Ortex)
183+
end
184+
end

0 commit comments

Comments
 (0)