|
3 | 3 | # This source code is licensed under the BSD-style license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
| 6 | +import os |
6 | 7 | from types import SimpleNamespace |
7 | 8 | from typing import cast |
| 9 | +from unittest import mock |
8 | 10 |
|
9 | 11 | import pytest |
10 | 12 |
|
|
14 | 16 | clear_registered_pass_insertions, |
15 | 17 | PassInsertions, |
16 | 18 | ) |
17 | | -from executorch.backends.arm.vgf import backend as vgf_backend, VgfCompileSpec |
| 19 | + |
| 20 | +from executorch.backends.arm.vgf import backend, backend as vgf_backend, VgfCompileSpec |
| 21 | +from executorch.backends.arm.vgf.backend import ( |
| 22 | + _copy_failure_artifacts, |
| 23 | + _format_repro_command, |
| 24 | + _replace_converter_input_path, |
| 25 | + vgf_compile, |
| 26 | +) |
18 | 27 | from executorch.exir.backend.backend_details import PreprocessResult |
19 | 28 | from executorch.exir.pass_base import ExportPass |
20 | 29 | from torch.export.exported_program import ExportedProgram |
@@ -105,3 +114,180 @@ def _raise(*args, **kwargs): |
105 | 114 | assert _registry_state() == original_registry |
106 | 115 | finally: |
107 | 116 | clear_registered_pass_insertions() |
| 117 | + |
| 118 | + |
| 119 | +def test_format_repro_command_quotes_shell_metacharacters(): |
| 120 | + command = [ |
| 121 | + "model-converter", |
| 122 | + "--flag=value with spaces", |
| 123 | + "-i", |
| 124 | + "input file.tosa", |
| 125 | + "-o", |
| 126 | + "output file.vgf", |
| 127 | + ] |
| 128 | + |
| 129 | + formatted = _format_repro_command(command) |
| 130 | + |
| 131 | + assert formatted == ( |
| 132 | + "model-converter " |
| 133 | + "'--flag=value with spaces' " |
| 134 | + "-i " |
| 135 | + "'input file.tosa' " |
| 136 | + "-o " |
| 137 | + "'output file.vgf'" |
| 138 | + ) |
| 139 | + |
| 140 | + |
| 141 | +def test_replace_converter_input_path_replaces_input_after_i(): |
| 142 | + command = [ |
| 143 | + "model-converter", |
| 144 | + "--some-flag", |
| 145 | + "-i", |
| 146 | + "original.tosa", |
| 147 | + "-o", |
| 148 | + "output.vgf", |
| 149 | + ] |
| 150 | + |
| 151 | + replaced = _replace_converter_input_path(command, "preserved.tosa") |
| 152 | + |
| 153 | + assert replaced == [ |
| 154 | + "model-converter", |
| 155 | + "--some-flag", |
| 156 | + "-i", |
| 157 | + "preserved.tosa", |
| 158 | + "-o", |
| 159 | + "output.vgf", |
| 160 | + ] |
| 161 | + assert command[3] == "original.tosa" |
| 162 | + |
| 163 | + |
| 164 | +def test_copy_failure_artifacts_returns_none_without_artifact_path(tmp_path): |
| 165 | + tosa_path = tmp_path / "input.tosa" |
| 166 | + tosa_path.write_bytes(b"tosa bytes") |
| 167 | + |
| 168 | + copied_path = _copy_failure_artifacts( |
| 169 | + str(tosa_path), |
| 170 | + artifact_path=None, |
| 171 | + tag_name="delegate_0", |
| 172 | + ) |
| 173 | + |
| 174 | + assert copied_path is None |
| 175 | + |
| 176 | + |
| 177 | +def test_copy_failure_artifacts_copies_tosa_with_tag_name(tmp_path): |
| 178 | + tosa_path = tmp_path / "input.tosa" |
| 179 | + artifact_path = tmp_path / "artifacts" |
| 180 | + tosa_path.write_bytes(b"tosa bytes") |
| 181 | + |
| 182 | + copied_path = _copy_failure_artifacts( |
| 183 | + str(tosa_path), |
| 184 | + str(artifact_path), |
| 185 | + tag_name="delegate_0", |
| 186 | + ) |
| 187 | + |
| 188 | + assert copied_path == os.path.join( |
| 189 | + str(artifact_path), |
| 190 | + "failed_model_converter_input_delegate_0.tosa", |
| 191 | + ) |
| 192 | + assert os.path.exists(copied_path) |
| 193 | + assert open(copied_path, "rb").read() == b"tosa bytes" |
| 194 | + |
| 195 | + |
| 196 | +def test_copy_failure_artifacts_copies_tosa_without_tag_name(tmp_path): |
| 197 | + tosa_path = tmp_path / "input.tosa" |
| 198 | + artifact_path = tmp_path / "artifacts" |
| 199 | + tosa_path.write_bytes(b"tosa bytes") |
| 200 | + |
| 201 | + copied_path = _copy_failure_artifacts( |
| 202 | + str(tosa_path), |
| 203 | + str(artifact_path), |
| 204 | + tag_name="", |
| 205 | + ) |
| 206 | + |
| 207 | + assert copied_path == os.path.join( |
| 208 | + str(artifact_path), |
| 209 | + "failed_model_converter_input.tosa", |
| 210 | + ) |
| 211 | + assert os.path.exists(copied_path) |
| 212 | + assert open(copied_path, "rb").read() == b"tosa bytes" |
| 213 | + |
| 214 | + |
| 215 | +@mock.patch("executorch.backends.arm.vgf.backend.model_converter_env") |
| 216 | +@mock.patch("executorch.backends.arm.vgf.backend.require_model_converter_binary") |
| 217 | +@mock.patch("executorch.backends.arm.vgf.backend.subprocess.run") |
| 218 | +def test_vgf_compile_failure_includes_repro_command_and_copies_tosa( |
| 219 | + mock_run, |
| 220 | + mock_require_model_converter_binary, |
| 221 | + mock_model_converter_env, |
| 222 | + tmp_path, |
| 223 | +): |
| 224 | + artifact_path = tmp_path / "artifacts" |
| 225 | + |
| 226 | + mock_require_model_converter_binary.return_value = "model-converter" |
| 227 | + mock_model_converter_env.return_value = {"PATH": "/test/bin"} |
| 228 | + mock_run.side_effect = backend.subprocess.CalledProcessError( |
| 229 | + returncode=1, |
| 230 | + cmd=["model-converter"], |
| 231 | + output=b"converter stdout", |
| 232 | + stderr=b"converter stderr", |
| 233 | + ) |
| 234 | + |
| 235 | + with pytest.raises(RuntimeError) as exc_info: |
| 236 | + vgf_compile( |
| 237 | + b"serialized tosa", |
| 238 | + ["--flag=value with spaces"], |
| 239 | + artifact_path=str(artifact_path), |
| 240 | + tag_name="delegate_0", |
| 241 | + ) |
| 242 | + |
| 243 | + copied_tosa_path = os.path.join( |
| 244 | + str(artifact_path), |
| 245 | + "failed_model_converter_input_delegate_0.tosa", |
| 246 | + ) |
| 247 | + |
| 248 | + assert os.path.exists(copied_tosa_path) |
| 249 | + assert open(copied_tosa_path, "rb").read() == b"serialized tosa" |
| 250 | + |
| 251 | + error = str(exc_info.value) |
| 252 | + assert "Vgf compiler failed." in error |
| 253 | + assert "Repro command:" in error |
| 254 | + assert "model-converter '--flag=value with spaces' -i" in error |
| 255 | + assert copied_tosa_path in error |
| 256 | + assert " -o " in error |
| 257 | + assert "Stderr:\nconverter stderr" in error |
| 258 | + assert "Stdout:\nconverter stdout" in error |
| 259 | + |
| 260 | + |
| 261 | +@mock.patch("executorch.backends.arm.vgf.backend.model_converter_env") |
| 262 | +@mock.patch("executorch.backends.arm.vgf.backend.require_model_converter_binary") |
| 263 | +@mock.patch("executorch.backends.arm.vgf.backend.subprocess.run") |
| 264 | +def test_vgf_compile_failure_includes_temp_repro_command_without_artifact_path( |
| 265 | + mock_run, |
| 266 | + mock_require_model_converter_binary, |
| 267 | + mock_model_converter_env, |
| 268 | +): |
| 269 | + mock_require_model_converter_binary.return_value = "model-converter" |
| 270 | + mock_model_converter_env.return_value = {"PATH": "/test/bin"} |
| 271 | + mock_run.side_effect = backend.subprocess.CalledProcessError( |
| 272 | + returncode=1, |
| 273 | + cmd=["model-converter"], |
| 274 | + output=b"converter stdout", |
| 275 | + stderr=b"converter stderr", |
| 276 | + ) |
| 277 | + |
| 278 | + with pytest.raises(RuntimeError) as exc_info: |
| 279 | + vgf_compile( |
| 280 | + b"serialized tosa", |
| 281 | + ["--some-flag"], |
| 282 | + artifact_path=None, |
| 283 | + tag_name="delegate_0", |
| 284 | + ) |
| 285 | + |
| 286 | + error = str(exc_info.value) |
| 287 | + assert "Vgf compiler failed." in error |
| 288 | + assert "Repro command:" in error |
| 289 | + assert "model-converter --some-flag -i" in error |
| 290 | + assert "output_delegate_0.tosa.vgf" in error |
| 291 | + assert "failed_model_converter_input_delegate_0.tosa" not in error |
| 292 | + assert "Stderr:\nconverter stderr" in error |
| 293 | + assert "Stdout:\nconverter stdout" in error |
0 commit comments