|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 | import logging |
7 | | - |
8 | 7 | from dataclasses import dataclass, field |
9 | 8 | from typing import Callable |
10 | 9 |
|
@@ -53,98 +52,98 @@ def __str__(self): |
53 | 52 | return self.name |
54 | 53 |
|
55 | 54 |
|
56 | | -def all_flows() -> dict[str, TestFlow]: |
57 | | - flows = [] |
58 | | - |
59 | | - from executorch.backends.test.suite.flows.portable import PORTABLE_TEST_FLOW |
| 55 | +def _try_import_flows( |
| 56 | + module_path: str, flow_names: list[str], backend_name: str |
| 57 | +) -> list[TestFlow]: |
| 58 | + """ |
| 59 | + Attempt to import test flows from a module. |
60 | 60 |
|
61 | | - flows += [ |
62 | | - PORTABLE_TEST_FLOW, |
63 | | - ] |
| 61 | + Args: |
| 62 | + module_path: The full module path to import from. |
| 63 | + flow_names: List of flow variable names to import from the module. |
| 64 | + backend_name: Human-readable name for logging on failure. |
64 | 65 |
|
| 66 | + Returns: |
| 67 | + List of imported TestFlow objects, or empty list if import fails. |
| 68 | + """ |
65 | 69 | try: |
66 | | - from executorch.backends.test.suite.flows.xnnpack import ( |
67 | | - XNNPACK_DYNAMIC_INT8_PER_CHANNEL_TEST_FLOW, |
68 | | - XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW, |
69 | | - XNNPACK_STATIC_INT8_PER_TENSOR_TEST_FLOW, |
70 | | - XNNPACK_TEST_FLOW, |
71 | | - ) |
72 | | - |
73 | | - flows += [ |
74 | | - XNNPACK_TEST_FLOW, |
75 | | - XNNPACK_DYNAMIC_INT8_PER_CHANNEL_TEST_FLOW, |
76 | | - XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW, |
77 | | - XNNPACK_STATIC_INT8_PER_TENSOR_TEST_FLOW, |
78 | | - ] |
79 | | - except Exception as e: |
80 | | - logger.info(f"Skipping XNNPACK flow registration: {e}") |
| 70 | + import importlib |
81 | 71 |
|
82 | | - try: |
83 | | - from executorch.backends.test.suite.flows.coreml import ( |
84 | | - COREML_STATIC_INT8_TEST_FLOW, |
85 | | - COREML_TEST_FLOW, |
86 | | - ) |
87 | | - |
88 | | - flows += [ |
89 | | - COREML_TEST_FLOW, |
90 | | - COREML_STATIC_INT8_TEST_FLOW, |
91 | | - ] |
| 72 | + module = importlib.import_module(module_path) |
| 73 | + return [getattr(module, name) for name in flow_names] |
92 | 74 | except Exception as e: |
93 | | - logger.info(f"Skipping Core ML flow registration: {e}") |
| 75 | + logger.info(f"Skipping {backend_name} flow registration: {e}") |
| 76 | + return [] |
| 77 | + |
| 78 | + |
| 79 | +# Registry of backend flows to import: (module_path, flow_names, backend_name) |
| 80 | +_FLOW_REGISTRY: list[tuple[str, list[str], str]] = [ |
| 81 | + ( |
| 82 | + "executorch.backends.test.suite.flows.xnnpack", |
| 83 | + [ |
| 84 | + "XNNPACK_TEST_FLOW", |
| 85 | + "XNNPACK_DYNAMIC_INT8_PER_CHANNEL_TEST_FLOW", |
| 86 | + "XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW", |
| 87 | + "XNNPACK_STATIC_INT8_PER_TENSOR_TEST_FLOW", |
| 88 | + ], |
| 89 | + "XNNPACK", |
| 90 | + ), |
| 91 | + ( |
| 92 | + "executorch.backends.test.suite.flows.coreml", |
| 93 | + [ |
| 94 | + "COREML_TEST_FLOW", |
| 95 | + "COREML_STATIC_INT8_TEST_FLOW", |
| 96 | + ], |
| 97 | + "Core ML", |
| 98 | + ), |
| 99 | + ( |
| 100 | + "executorch.backends.test.suite.flows.vulkan", |
| 101 | + [ |
| 102 | + "VULKAN_TEST_FLOW", |
| 103 | + "VULKAN_STATIC_INT8_PER_CHANNEL_TEST_FLOW", |
| 104 | + ], |
| 105 | + "Vulkan", |
| 106 | + ), |
| 107 | + ( |
| 108 | + "executorch.backends.test.suite.flows.qualcomm", |
| 109 | + [ |
| 110 | + "QNN_TEST_FLOW", |
| 111 | + "QNN_16A16W_TEST_FLOW", |
| 112 | + "QNN_16A8W_TEST_FLOW", |
| 113 | + "QNN_16A4W_TEST_FLOW", |
| 114 | + "QNN_16A4W_BLOCK_TEST_FLOW", |
| 115 | + "QNN_8A8W_TEST_FLOW", |
| 116 | + ], |
| 117 | + "QNN", |
| 118 | + ), |
| 119 | + ( |
| 120 | + "executorch.backends.test.suite.flows.arm", |
| 121 | + [ |
| 122 | + "ARM_TOSA_FP_FLOW", |
| 123 | + "ARM_TOSA_INT_FLOW", |
| 124 | + "ARM_ETHOS_U55_FLOW", |
| 125 | + "ARM_ETHOS_U85_FLOW", |
| 126 | + "ARM_VGF_FP_FLOW", |
| 127 | + "ARM_VGF_INT_FLOW", |
| 128 | + ], |
| 129 | + "ARM", |
| 130 | + ), |
| 131 | + ( |
| 132 | + "executorch.backends.test.suite.flows.cuda", |
| 133 | + [ |
| 134 | + "CUDA_TEST_FLOW", |
| 135 | + ], |
| 136 | + "CUDA", |
| 137 | + ), |
| 138 | +] |
94 | 139 |
|
95 | | - try: |
96 | | - from executorch.backends.test.suite.flows.vulkan import ( |
97 | | - VULKAN_STATIC_INT8_PER_CHANNEL_TEST_FLOW, |
98 | | - VULKAN_TEST_FLOW, |
99 | | - ) |
100 | | - |
101 | | - flows += [ |
102 | | - VULKAN_TEST_FLOW, |
103 | | - VULKAN_STATIC_INT8_PER_CHANNEL_TEST_FLOW, |
104 | | - ] |
105 | | - except Exception as e: |
106 | | - logger.info(f"Skipping Vulkan flow registration: {e}") |
107 | 140 |
|
108 | | - try: |
109 | | - from executorch.backends.test.suite.flows.qualcomm import ( |
110 | | - QNN_16A16W_TEST_FLOW, |
111 | | - QNN_16A4W_BLOCK_TEST_FLOW, |
112 | | - QNN_16A4W_TEST_FLOW, |
113 | | - QNN_16A8W_TEST_FLOW, |
114 | | - QNN_8A8W_TEST_FLOW, |
115 | | - QNN_TEST_FLOW, |
116 | | - ) |
117 | | - |
118 | | - flows += [ |
119 | | - QNN_TEST_FLOW, |
120 | | - QNN_16A16W_TEST_FLOW, |
121 | | - QNN_16A8W_TEST_FLOW, |
122 | | - QNN_16A4W_TEST_FLOW, |
123 | | - QNN_16A4W_BLOCK_TEST_FLOW, |
124 | | - QNN_8A8W_TEST_FLOW, |
125 | | - ] |
126 | | - except Exception as e: |
127 | | - logger.info(f"Skipping QNN flow registration: {e}") |
| 141 | +def all_flows() -> dict[str, TestFlow]: |
| 142 | + from executorch.backends.test.suite.flows.portable import PORTABLE_TEST_FLOW |
128 | 143 |
|
129 | | - try: |
130 | | - from executorch.backends.test.suite.flows.arm import ( |
131 | | - ARM_ETHOS_U55_FLOW, |
132 | | - ARM_ETHOS_U85_FLOW, |
133 | | - ARM_TOSA_FP_FLOW, |
134 | | - ARM_TOSA_INT_FLOW, |
135 | | - ARM_VGF_FP_FLOW, |
136 | | - ARM_VGF_INT_FLOW, |
137 | | - ) |
138 | | - |
139 | | - flows += [ |
140 | | - ARM_TOSA_FP_FLOW, |
141 | | - ARM_TOSA_INT_FLOW, |
142 | | - ARM_ETHOS_U55_FLOW, |
143 | | - ARM_ETHOS_U85_FLOW, |
144 | | - ARM_VGF_FP_FLOW, |
145 | | - ARM_VGF_INT_FLOW, |
146 | | - ] |
147 | | - except Exception as e: |
148 | | - logger.info(f"Skipping ARM flow registration: {e}") |
| 144 | + flows = [PORTABLE_TEST_FLOW] |
| 145 | + |
| 146 | + for module_path, flow_names, backend_name in _FLOW_REGISTRY: |
| 147 | + flows.extend(_try_import_flows(module_path, flow_names, backend_name)) |
149 | 148 |
|
150 | 149 | return {f.name: f for f in flows if f is not None} |
0 commit comments