|
| 1 | +# Copyright 2026 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +from types import SimpleNamespace |
| 7 | + |
| 8 | +from executorch.backends.arm.tosa import partitioner as tosa_partitioner |
| 9 | +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec |
| 10 | +from executorch.backends.arm.tosa.partitioner import TOSAPartitioner |
| 11 | + |
| 12 | + |
| 13 | +class _FakeCapabilityBasedPartitioner: |
| 14 | + def __init__(self, *args, **kwargs) -> None: |
| 15 | + pass |
| 16 | + |
| 17 | + def propose_partitions(self): |
| 18 | + return [ |
| 19 | + SimpleNamespace(nodes=[SimpleNamespace(meta={}, target=f"op{idx}")]) |
| 20 | + for idx in range(3) |
| 21 | + ] |
| 22 | + |
| 23 | + |
| 24 | +def _make_reporter() -> SimpleNamespace: |
| 25 | + return SimpleNamespace( |
| 26 | + report_reject=lambda *args, **kwargs: None, |
| 27 | + get_table_report=lambda: "", |
| 28 | + ) |
| 29 | + |
| 30 | + |
| 31 | +def test_tag_module_preserves_partition_discovery_order(monkeypatch): |
| 32 | + partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP")) |
| 33 | + |
| 34 | + monkeypatch.setattr( |
| 35 | + tosa_partitioner, "get_cond_while_submodules_nested", lambda module: [] |
| 36 | + ) |
| 37 | + monkeypatch.setattr( |
| 38 | + tosa_partitioner, "tosa_support_factory", lambda *args, **kwargs: object() |
| 39 | + ) |
| 40 | + monkeypatch.setattr( |
| 41 | + tosa_partitioner, |
| 42 | + "CapabilityBasedPartitioner", |
| 43 | + _FakeCapabilityBasedPartitioner, |
| 44 | + ) |
| 45 | + monkeypatch.setattr( |
| 46 | + partitioner, |
| 47 | + "_partition_has_invalid_uint8", |
| 48 | + lambda partition, tag: False, |
| 49 | + ) |
| 50 | + monkeypatch.setattr( |
| 51 | + partitioner, |
| 52 | + "_preserve_io_quantization_enabled", |
| 53 | + lambda: False, |
| 54 | + ) |
| 55 | + |
| 56 | + tags = partitioner._tag_module( |
| 57 | + SimpleNamespace(graph=SimpleNamespace(nodes=[])), |
| 58 | + SimpleNamespace(), |
| 59 | + _make_reporter(), |
| 60 | + ) |
| 61 | + |
| 62 | + assert tags == ["tag0", "tag1", "tag2"] |
| 63 | + |
| 64 | + |
| 65 | +def test_partition_preserves_tag_discovery_order(monkeypatch): |
| 66 | + partitioner = TOSAPartitioner(TosaCompileSpec("TOSA-1.0+FP")) |
| 67 | + |
| 68 | + monkeypatch.setattr( |
| 69 | + partitioner, |
| 70 | + "_tag_module", |
| 71 | + lambda *args, **kwargs: ["tag2", "tag10"], |
| 72 | + ) |
| 73 | + monkeypatch.setattr(tosa_partitioner, "tag_constant_data", lambda program: None) |
| 74 | + monkeypatch.setattr( |
| 75 | + tosa_partitioner, "WhyNoPartitionReporter", _make_reporter |
| 76 | + ) |
| 77 | + |
| 78 | + result = partitioner.partition(SimpleNamespace(graph_module=SimpleNamespace())) |
| 79 | + |
| 80 | + assert list(result.partition_tags) == ["tag2", "tag10"] |
0 commit comments