Skip to content

Commit f7bc754

Browse files
committed
Use EXLA as default Nx backend in test env to speed up ML tests
1 parent 26f474e commit f7bc754

2 files changed

Lines changed: 30 additions & 16 deletions

File tree

Dockerfile.ortex-precompiled

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -104,10 +104,14 @@ end
104104
LIBEOF
105105
EOF
106106

107-
# Fetch the tinymodel.onnx fixture from the Ortex repo. This file
108-
# defines a single Float32 input of shape [nil, 100] and a single
109-
# Float32 output of shape [nil, 200] — small enough to load and run
110-
# in milliseconds.
107+
# Fetch the tinymodel.onnx fixture from the Ortex repo. This is the
108+
# tiny model Ortex's own doctests use. Its current schema:
109+
#
110+
# inputs: x (Int32, [-1, 100])
111+
# y (Float32, [-1, 100])
112+
# outputs: output1, output2, output3 (Float32, [-1, 10])
113+
#
114+
# Small enough to load and run in milliseconds.
111115
RUN mkdir -p models \
112116
&& curl -fsSL -o models/tinymodel.onnx \
113117
"https://raw.githubusercontent.com/${ORTEX_REPO}/${ORTEX_MODEL_REF}/models/tinymodel.onnx" \
@@ -128,22 +132,27 @@ RUN echo "=== Ortex priv/native contents ===" \
128132
&& ldd deps/ortex/priv/native/*.so 2>&1 || true
129133

130134
# End-to-end smoke test: load the NIF, load tinymodel.onnx, run a
131-
# forward pass with zeroed input. Failure here means the NIF loaded
135+
# forward pass with zeroed inputs. Failure here means the NIF loaded
132136
# but onnxruntime is broken in some other way.
137+
#
138+
# tinymodel.onnx takes a tuple of {x: int32[-1, 100], y: float32[-1, 100]}
139+
# and returns a tuple of three float32[-1, 10] tensors. Each output's
140+
# shape with batch=1 is {1, 10}.
133141
RUN mix run -e ' \
134142
IO.puts("--- Ortex NIF smoke test ---"); \
135143
exports = Ortex.Native.module_info(:exports); \
136144
IO.puts("Ortex.Native exports #{length(exports)} functions"); \
137145
model = Ortex.load("./models/tinymodel.onnx"); \
138146
IO.puts("Loaded: #{inspect(model)}"); \
139-
input = Nx.broadcast(0.0, {1, 100}) |> Nx.as_type(:f32); \
140-
{output} = Ortex.run(model, input); \
141-
shape = output |> Nx.backend_transfer() |> Nx.shape(); \
142-
IO.puts("Inference output shape: #{inspect(shape)}"); \
143-
if shape == {1, 200} do \
147+
x = Nx.broadcast(0, {1, 100}) |> Nx.as_type(:s32); \
148+
y = Nx.broadcast(0.0, {1, 100}) |> Nx.as_type(:f32); \
149+
{out1, out2, out3} = Ortex.run(model, {x, y}); \
150+
shapes = Enum.map([out1, out2, out3], &(Nx.backend_transfer(&1) |> Nx.shape())); \
151+
IO.puts("Inference output shapes: #{inspect(shapes)}"); \
152+
if shapes == [{1, 10}, {1, 10}, {1, 10}] do \
144153
IO.puts("PASS: Ortex precompiled NIF works end-to-end."); \
145154
else \
146-
IO.puts("FAIL: unexpected output shape #{inspect(shape)}"); \
155+
IO.puts("FAIL: unexpected output shapes #{inspect(shapes)}"); \
147156
System.halt(1); \
148157
end \
149158
'

config/test.exs

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,13 @@ import Config
33
config :logger,
44
level: :warning
55

6-
# Route all Nx.Defn computations (including Bumblebee featurizer
7-
# preprocessing) through EXLA. This includes Apple Silicon — EXLA's
8-
# XLA CPU path uses NEON/AMX and is significantly faster than the
9-
# pure-Elixir Nx.Defn.Evaluator for both preprocessing and inference.
10-
config :nx, :default_defn_options, compiler: EXLA
6+
# Route all Nx tensor allocations and Nx.Defn computations through
7+
# EXLA. Without `default_backend`, only `defn`-compiled inference
8+
# uses EXLA — the surrounding tensor work (image preprocessing,
9+
# output reshaping, similarity dot products, etc.) falls back to
10+
# the pure-Elixir Nx.BinaryBackend, which is orders of magnitude
11+
# slower for image-sized tensors. This affected ML test runtime
12+
# significantly before we set it.
13+
config :nx,
14+
default_backend: EXLA.Backend,
15+
default_defn_options: [compiler: EXLA]

0 commit comments

Comments (0)