|
5 | 5 | # except in compliance with the License. See the license file in the root |
6 | 6 | # directory of this source tree for more details. |
7 | 7 |
|
8 | | - |
| 8 | +import logging |
| 9 | +import os |
9 | 10 | import unittest |
10 | 11 |
|
11 | 12 | from executorch.backends.samsung.serialization.compile_options import ( |
12 | 13 | gen_samsung_backend_compile_spec, |
13 | 14 | ) |
14 | 15 | from executorch.backends.samsung.test.tester import SamsungTester |
| 16 | +from executorch.backends.samsung.test.utils.utils import TestConfig |
15 | 17 | from executorch.examples.models.inception_v4 import InceptionV4Model |
16 | 18 |
|
17 | 19 |
|
| 20 | +def patch_iv4(weight_path: str): |
| 21 | + assert os.path.isfile(weight_path), "Can not found weight path for iv4" |
| 22 | + from safetensors import safe_open |
| 23 | + from timm.models import inception_v4 |
| 24 | + |
| 25 | + def _monkeypatch_get_eager_model(self): |
| 26 | + tensors = {} |
| 27 | + with safe_open(weight_path, framework="pt") as st: |
| 28 | + for k in st.keys(): |
| 29 | + tensors[k] = st.get_tensor(k) |
| 30 | + logging.info("Loading inception_v4 model") |
| 31 | + m = inception_v4(pretrained=True, pretrained_cfg={"state_dict": tensors}) |
| 32 | + logging.info("Loaded inception_v4 model") |
| 33 | + return m |
| 34 | + |
| 35 | + old_func = InceptionV4Model.get_eager_model |
| 36 | + InceptionV4Model.get_eager_model = _monkeypatch_get_eager_model |
| 37 | + return old_func |
| 38 | + |
| 39 | + |
| 40 | +def recover_iv4(old_func): |
| 41 | + InceptionV4Model.get_eager_model = old_func |
| 42 | + |
| 43 | + |
18 | 44 | class TestMilestoneInceptionV4(unittest.TestCase): |
| 45 | + @classmethod |
| 46 | + def setUpClass(cls): |
| 47 | + assert (model_cache_dir := os.getenv("MODEL_CACHE")), "MODEL_CACHE not set!" |
| 48 | + weight_path = os.path.join( |
| 49 | + model_cache_dir, os.path.join(model_cache_dir, "iv4/model.safetensors") |
| 50 | + ) |
| 51 | + cls._old_func = patch_iv4(weight_path) |
| 52 | + |
| 53 | + @classmethod |
| 54 | + def tearDownClass(cls): |
| 55 | + recover_iv4(cls._old_func) |
| 56 | + |
19 | 57 | def test_inception_v4_fp16(self): |
20 | 58 | model = InceptionV4Model().get_eager_model() |
21 | 59 | example_input = InceptionV4Model().get_example_inputs() |
22 | 60 | tester = SamsungTester( |
23 | | - model, example_input, [gen_samsung_backend_compile_spec("E9955")] |
| 61 | + model, example_input, [gen_samsung_backend_compile_spec(TestConfig.chipset)] |
24 | 62 | ) |
25 | 63 | ( |
26 | 64 | tester.export() |
|
0 commit comments