Skip to content

Commit 1890ac6

Browse files
Merge branch 'main' into add-support-for-erfinv
2 parents 9a0c0af + 6b5283c commit 1890ac6

113 files changed

Lines changed: 3040 additions & 437 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/scripts/build-qnn-sdk.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ build_qnn_backend() {
1818
export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)"
1919

2020
parallelism=$(( $(nproc) - 1 ))
21-
bash backends/qualcomm/scripts/build.sh --skip_linux_android --skip_linux_embedded --job_number ${parallelism} --release
21+
bash backends/qualcomm/scripts/build.sh --skip_linux_android --job_number ${parallelism} --release
2222
}
2323

2424
set_up_aot() {

.ci/scripts/export_model_artifact.sh

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Arguments:
2222
- mistralai/Voxtral-Mini-4B-Realtime-2602
2323
- openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo})
2424
- google/gemma-3-4b-it
25+
- nvidia/diar_streaming_sortformer_4spk-v2
2526
- nvidia/parakeet-tdt
2627
2728
quant_name Quantization type (optional, default: non-quantized)
@@ -45,6 +46,7 @@ Examples:
4546
export_model_artifact.sh metal "mistralai/Voxtral-Mini-4B-Realtime-2602" "quantized-int4-metal"
4647
export_model_artifact.sh metal "mistralai/Voxtral-Mini-4B-Realtime-2602" "non-quantized" "." "vr-streaming"
4748
export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
49+
export_model_artifact.sh cuda-windows "nvidia/diar_streaming_sortformer_4spk-v2" "non-quantized" "./output"
4850
export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
4951
export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
5052
export_model_artifact.sh xnnpack "nvidia/parakeet-tdt" "quantized-8da4w" "./output"
@@ -157,6 +159,14 @@ case "$HF_MODEL" in
157159
PREPROCESSOR_FEATURE_SIZE=""
158160
PREPROCESSOR_OUTPUT=""
159161
;;
162+
nvidia/diar_streaming_sortformer_4spk-v2)
163+
MODEL_NAME="sortformer"
164+
TASK=""
165+
MAX_SEQ_LEN=""
166+
EXTRA_PIP=""
167+
PREPROCESSOR_FEATURE_SIZE=""
168+
PREPROCESSOR_OUTPUT=""
169+
;;
160170
mistralai/Voxtral-Mini-4B-Realtime-2602)
161171
MODEL_NAME="voxtral_realtime"
162172
TASK=""
@@ -167,7 +177,7 @@ case "$HF_MODEL" in
167177
;;
168178
*)
169179
echo "Error: Unsupported model '$HF_MODEL'"
170-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt"
180+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt"
171181
exit 1
172182
;;
173183
esac
@@ -247,6 +257,42 @@ if [ "$MODEL_NAME" = "parakeet" ]; then
247257
exit 0
248258
fi
249259

260+
# Sortformer uses a custom export script
261+
if [ "$MODEL_NAME" = "sortformer" ]; then
262+
if [ "$QUANT_NAME" != "non-quantized" ]; then
263+
echo "Error: Sortformer currently supports only non-quantized export"
264+
exit 1
265+
fi
266+
267+
pip install -r examples/models/sortformer/install_requirements.txt
268+
269+
SORTFORMER_BACKEND="$DEVICE"
270+
if [ "$DEVICE" = "cuda-windows" ]; then
271+
SORTFORMER_BACKEND="cuda-windows"
272+
elif [ "$DEVICE" = "cuda" ]; then
273+
SORTFORMER_BACKEND="cuda"
274+
elif [ "$DEVICE" = "xnnpack" ]; then
275+
SORTFORMER_BACKEND="xnnpack"
276+
else
277+
SORTFORMER_BACKEND="portable"
278+
fi
279+
280+
python -m executorch.examples.models.sortformer.export_sortformer \
281+
--hf-model "${HF_MODEL}" \
282+
--backend "${SORTFORMER_BACKEND}" \
283+
--output-dir "${OUTPUT_DIR}"
284+
285+
test -f "${OUTPUT_DIR}/sortformer.pte"
286+
mv "${OUTPUT_DIR}/sortformer.pte" "${OUTPUT_DIR}/model.pte"
287+
# CUDA saves named data to separate .ptd file, XNNPACK/portable do not.
288+
if [ "$DEVICE" = "cuda" ] || [ "$DEVICE" = "cuda-windows" ]; then
289+
test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd"
290+
fi
291+
ls -al "${OUTPUT_DIR}"
292+
echo "::endgroup::"
293+
exit 0
294+
fi
295+
250296
# Voxtral Realtime uses a custom export script
251297
if [ "$MODEL_NAME" = "voxtral_realtime" ]; then
252298
pip install safetensors huggingface_hub

.ci/scripts/test_model_e2e.sh

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Arguments:
1919
hf_model HuggingFace model ID (required)
2020
Supported models:
2121
- mistralai/Voxtral-Mini-3B-2507
22+
- nvidia/diar_streaming_sortformer_4spk-v2
2223
- openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo})
2324
- google/gemma-3-4b-it
2425
- Qwen/Qwen3-0.6B
@@ -44,6 +45,7 @@ Arguments:
4445
Examples:
4546
test_model_e2e.sh metal "openai/whisper-small" "non-quantized"
4647
test_model_e2e.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output"
48+
test_model_e2e.sh cuda "nvidia/diar_streaming_sortformer_4spk-v2" "non-quantized" "./model_output"
4749
test_model_e2e.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./model_output"
4850
test_model_e2e.sh xnnpack "nvidia/parakeet-tdt" "quantized-8da4w" "./model_output"
4951
test_model_e2e.sh metal "mistralai/Voxtral-Mini-4B-Realtime-2602" "non-quantized" "." "vr-streaming"
@@ -176,6 +178,18 @@ case "$HF_MODEL" in
176178
AUDIO_FILE="test_audio.wav"
177179
IMAGE_PATH=""
178180
;;
181+
nvidia/diar_streaming_sortformer_4spk-v2)
182+
MODEL_NAME="sortformer"
183+
RUNNER_TARGET="sortformer_runner"
184+
RUNNER_PATH="sortformer"
185+
EXPECTED_OUTPUT="Speaker 1"
186+
PREPROCESSOR=""
187+
TOKENIZER_URL=""
188+
TOKENIZER_FILE=""
189+
AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav"
190+
AUDIO_FILE="poem.wav"
191+
IMAGE_PATH=""
192+
;;
179193
mistralai/Voxtral-Mini-4B-Realtime-2602)
180194
MODEL_NAME="voxtral_realtime"
181195
RUNNER_TARGET="voxtral_realtime_runner"
@@ -190,7 +204,7 @@ case "$HF_MODEL" in
190204
;;
191205
*)
192206
echo "Error: Unsupported model '$HF_MODEL'"
193-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt"
207+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt"
194208
exit 1
195209
;;
196210
esac
@@ -203,8 +217,8 @@ echo "::endgroup::"
203217
echo "::group::Prepare $MODEL_NAME Artifacts"
204218

205219

206-
# Download tokenizer files (skip for parakeet and voxtral_realtime which bundle tokenizer in export)
207-
if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ]; then
220+
# Download tokenizer files (skip for models that bundle tokenizer in export or do not use one)
221+
if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ]; then
208222
if [ "$TOKENIZER_FILE" != "" ]; then
209223
curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE
210224
else
@@ -296,6 +310,12 @@ EOF
296310
RUNNER_ARGS="$RUNNER_ARGS --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd"
297311
fi
298312
;;
313+
sortformer)
314+
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --audio_path ${MODEL_DIR}/$AUDIO_FILE"
315+
if [ "$DEVICE" = "cuda" ]; then
316+
RUNNER_ARGS="$RUNNER_ARGS --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd"
317+
fi
318+
;;
299319
voxtral_realtime)
300320
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"
301321
# Add CUDA data path if present

.ci/scripts/test_model_e2e_windows.ps1

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ switch ($HfModel) {
6464
$audioUrl = "https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav"
6565
$audioFile = "test_audio.wav"
6666
}
67+
"nvidia/diar_streaming_sortformer_4spk-v2" {
68+
$runnerTarget = "sortformer_runner"
69+
$runnerPath = "sortformer"
70+
$runnerPreset = "sortformer-cuda"
71+
$expectedOutput = "Speaker 1"
72+
$preprocessor = ""
73+
$tokenizerUrl = ""
74+
$tokenizerFile = ""
75+
$audioUrl = "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav"
76+
$audioFile = "poem.wav"
77+
}
6778
"mistralai/Voxtral-Mini-4B-Realtime-2602" {
6879
$runnerTarget = "voxtral_realtime_runner"
6980
$runnerPath = "voxtral_realtime"
@@ -76,7 +87,7 @@ switch ($HfModel) {
7687
$audioFile = "poem.wav"
7788
}
7889
default {
79-
throw "Unsupported model '$HfModel'. Supported: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/parakeet-tdt"
90+
throw "Unsupported model '$HfModel'. Supported: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt"
8091
}
8192
}
8293

@@ -182,6 +193,13 @@ try {
182193
"--data_path", $cudaBlob
183194
)
184195
}
196+
"nvidia/diar_streaming_sortformer_4spk-v2" {
197+
$runnerArgs = @(
198+
"--model_path", $modelPte,
199+
"--audio_path", (Join-Path -Path $resolvedModelDir -ChildPath $audioFile),
200+
"--data_path", $cudaBlob
201+
)
202+
}
185203
"mistralai/Voxtral-Mini-4B-Realtime-2602" {
186204
$runnerArgs += @(
187205
"--temperature", "0",

.github/workflows/add-unanswered-to-project.yml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,25 @@ jobs:
4343
"ethansfng", "ThomasJannaud", "nirvanagth", "marcinkwiatkowski", "3l1", "omerjerk", "nitish2112", "yipjustin",
4444
"ejnguyen", "andrewor14", "phaiting", "mgiordy", "LeeOHzzZ", "adicatana", "Polyomino", "ezrilow", "navsud",
4545
"michaelmaitland", "RahulC7", "seyeong-han", "thdusdl1219", "jaejunku", "felixweilbach", "apullin", "trviv", "junluan01",
46-
"mvartani-meta", "abeakkas", "elpdumont", "corporateshark", "bdemirb", "GeorgeTzoupis", "AdithyaReddy9", "YifanShenSZ",
47-
"RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat","azad-meta", "junpi", "pytorchbot", "pytorchmergebot", "pytorchupdatebot",
46+
"mvartani-meta", "abeakkas", "elpdumont", "corporateshark", "bdemirb", "GeorgeTzoupis", "AdithyaReddy9", "drinkmorewaterr",
47+
"YifanShenSZ", "RdoubleA", "Olivia-liu", "Abhi-hpp", "Vysarat","azad-meta", "junpi", "pytorchbot", "pytorchmergebot", "pytorchupdatebot",
4848
"facebook-github-bot", "app/dependabot", "Erik-Lundell", "zingo", "AdrianLundell", "oscarandersson8218", "per",
4949
"Sebastian-Larsson", "SaoirseARM", "robell", "mansnils", "martinlsm", "freddan80", "YufengShi-dudu", "tom-arm", "perheld",
5050
"Jerry-Ge", "gggekov", "fumchin", "wwwind", "benkli01", "Tessil", "maddun01", "Michiel-Olieslagers", "armwaheed", "agrima1304",
5151
"emmakujala", "annietllnd", "MatthiasHertel80", "AlexTawseArm", "jmahbs", "morgolock", "Christoffer-JL", "ArmRyan", "xingguo01",
52-
"tgonzalezorlandoarm", "chizkiyahu", "sarah-blades", "haowhsu-quic", "shewu-quic", "winskuo-quic", "chunit-quic", "DannyYuyang-quic",
53-
"chuntl", "thchenqti", "jethroqti", "chenweng-quic", "cymbalrush", "DenisVieriu97", "billmguo", "StrycekSimon", "jirioc",
54-
"robert-kalmar", "skywall", "MartinPavella", "roman-janik-nxp", "novak-vaclav", "neuropilot-captain", "dijopaul", "cad-rlc",
55-
"cad-audio", "ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", "cavusmustafa", "anzr299", "Jiseong-oh", "alexdean08",
52+
"tgonzalezorlandoarm", "chizkiyahu", "sarah-blades", "itsMarco-G", "usamahz", "haowhsu-quic", "shewu-quic", "winskuo-quic",
53+
"chunit-quic", "DannyYuyang-quic", "chuntl", "thchenqti", "jethroqti", "chenweng-quic", "cymbalrush", "DenisVieriu97", "billmguo",
54+
"StrycekSimon", "jirioc", "robert-kalmar", "skywall", "MartinPavella", "roman-janik-nxp", "novak-vaclav", "neuropilot-captain",
55+
"dijopaul", "cad-rlc", "cad-audio", "ynimmaga", "daniil-lyakhov", "emmanuel-ferdman", "cavusmustafa", "anzr299", "suryasidd",
56+
"Jiseong-oh", "alexdean08",
5657
// explicitly include the dependabot bot login seen in PRs
5758
"dependabot[bot]"
5859
]);
5960
6061
// List of organization logins (lowercased) to exclude members of
6162
const excludedOrgs = new Set([
62-
"meta", "facebook", "pytorch", "arm", "apple", "qualcomm", "nxp", "mediatek", "cadence", "intel", "samsung"
63+
"meta", "facebook", "pytorch", "arm", "apple", "qualcomm", "nxp", "mediatek", "cadence", "intel", "samsung",
64+
"@meta", "@facebook", "@pytorch", "@arm", "@apple", "@qualcomm", "@nxp", "@mediatek", "@cadence", "@intel", "@samsung"
6365
]);
6466
6567
// Labels on PRs to exclude from being added to the project

.github/workflows/cuda-windows.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ jobs:
4141
- model_repo: "nvidia"
4242
model_name: "parakeet-tdt"
4343
quant: "quantized-int4-weight-only"
44+
- model_repo: "nvidia"
45+
model_name: "diar_streaming_sortformer_4spk-v2"
46+
quant: "non-quantized"
4447
- model_repo: "mistralai"
4548
model_name: "Voxtral-Mini-4B-Realtime-2602"
4649
quant: "quantized-int4-tile-packed"
@@ -113,6 +116,9 @@ jobs:
113116
- model_repo: "nvidia"
114117
model_name: "parakeet-tdt"
115118
quant: "quantized-int4-weight-only"
119+
- model_repo: "nvidia"
120+
model_name: "diar_streaming_sortformer_4spk-v2"
121+
quant: "non-quantized"
116122
- model_repo: "mistralai"
117123
model_name: "Voxtral-Mini-4B-Realtime-2602"
118124
quant: "quantized-int4-tile-packed"

.github/workflows/cuda.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ jobs:
139139
name: "Voxtral-Mini-3B-2507"
140140
- repo: "mistralai"
141141
name: "Voxtral-Mini-4B-Realtime-2602"
142+
- repo: "nvidia"
143+
name: "diar_streaming_sortformer_4spk-v2"
142144
- repo: "openai"
143145
name: "whisper-small"
144146
- repo: "openai"
@@ -168,6 +170,15 @@ jobs:
168170
repo: "mistralai"
169171
name: "Voxtral-Mini-4B-Realtime-2602"
170172
quant: "quantized-int4-weight-only"
173+
# Sortformer currently supports only non-quantized export
174+
- model:
175+
repo: "nvidia"
176+
name: "diar_streaming_sortformer_4spk-v2"
177+
quant: "quantized-int4-tile-packed"
178+
- model:
179+
repo: "nvidia"
180+
name: "diar_streaming_sortformer_4spk-v2"
181+
quant: "quantized-int4-weight-only"
171182
with:
172183
timeout: 90
173184
secrets-env: EXECUTORCH_HF_TOKEN
@@ -214,6 +225,8 @@ jobs:
214225
name: "Voxtral-Mini-3B-2507"
215226
- repo: "mistralai"
216227
name: "Voxtral-Mini-4B-Realtime-2602"
228+
- repo: "nvidia"
229+
name: "diar_streaming_sortformer_4spk-v2"
217230
- repo: "openai"
218231
name: "whisper-small"
219232
- repo: "openai"
@@ -241,6 +254,15 @@ jobs:
241254
repo: "mistralai"
242255
name: "Voxtral-Mini-4B-Realtime-2602"
243256
quant: "quantized-int4-weight-only"
257+
# Sortformer currently supports only non-quantized export
258+
- model:
259+
repo: "nvidia"
260+
name: "diar_streaming_sortformer_4spk-v2"
261+
quant: "quantized-int4-tile-packed"
262+
- model:
263+
repo: "nvidia"
264+
name: "diar_streaming_sortformer_4spk-v2"
265+
quant: "quantized-int4-weight-only"
244266
with:
245267
timeout: 90
246268
runner: linux.g5.4xlarge.nvidia.gpu

Makefile

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# - voxtral_realtime: Realtime speech-to-text model (CPU, CUDA, Metal)
1919
# - whisper: Speech recognition model (CPU, CUDA, Metal)
2020
# - parakeet: Speech recognition model (CPU, CUDA, Metal)
21-
# - sortformer: Speaker diarization model (CPU)
21+
# - sortformer: Speaker diarization model (CPU, CUDA)
2222
# - silero_vad: Voice activity detection model (CPU)
2323
# - llama: Text generation model (CPU)
2424
# - llava: Vision + language model (CPU)
@@ -91,7 +91,7 @@
9191
#
9292
# ==============================================================================
9393

94-
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
94+
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
9595

9696
help:
9797
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -109,6 +109,7 @@ help:
109109
@echo " parakeet-cuda-debug - Build Parakeet runner with CUDA backend (debug mode)"
110110
@echo " parakeet-cpu - Build Parakeet runner with CPU backend"
111111
@echo " parakeet-metal - Build Parakeet runner with Metal backend (macOS only)"
112+
@echo " sortformer-cuda - Build Sortformer runner with CUDA backend"
112113
@echo " sortformer-cpu - Build Sortformer runner with CPU backend"
113114
@echo " silero-vad-cpu - Build Silero VAD runner with CPU backend"
114115
@echo " llama-cuda - Build Llama runner with CUDA backend"
@@ -218,6 +219,15 @@ parakeet-metal:
218219
@echo "✓ Build complete!"
219220
@echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner"
220221

222+
sortformer-cuda:
223+
@echo "==> Building and installing ExecuTorch with CUDA..."
224+
cmake --workflow --preset llm-release-cuda
225+
@echo "==> Building Sortformer runner with CUDA..."
226+
cd examples/models/sortformer && cmake --workflow --preset sortformer-cuda
227+
@echo ""
228+
@echo "✓ Build complete!"
229+
@echo " Binary: cmake-out/examples/models/sortformer/sortformer_runner"
230+
221231
sortformer-cpu:
222232
@echo "==> Building and installing ExecuTorch..."
223233
cmake --workflow --preset llm-release

backends/arm/_passes/replace_inf_and_limit_values_pass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ def call(self, graph_module: torch.fx.GraphModule):
5252

5353
modified = True
5454
# 255 here is mainly for attention_mask in Llama for reasonable quant scale
55-
tensor[tensor == float("inf")] = 255
56-
tensor[tensor == float("-inf")] = -255
57-
setattr(graph_module, buf_name, tensor)
55+
t = torch.nan_to_num(tensor, posinf=255, neginf=-255)
56+
setattr(graph_module, buf_name, t)
5857

5958
for node in graph_module.graph.nodes:
6059
arg_list = list(node.args)

backends/arm/arm_vela.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,7 @@ def run(dir: str) -> bytes:
8888
args.append("--verbose-all")
8989
vela.main(" ".join(args).split(" "))
9090

91-
if any("ethos-u85" in arg for arg in args) or any(
92-
"debug-force-regor" in arg for arg in args
93-
):
94-
np_path = os.path.join(dir, "output", "out_vela.npz")
95-
else:
96-
np_path = os.path.join(dir, "output", "out_sg0_vela.npz")
91+
np_path = os.path.join(dir, "output", "out_vela.npz")
9792

9893
blocks = b""
9994
with np.load(np_path, allow_pickle=False) as data:

0 commit comments

Comments
 (0)