Skip to content

Commit d83f031

Browse files
fbarchardxnnpack-bot
authored andcommitted
Use immutabledict for global constants in python scripts
PiperOrigin-RevId: 914562268
1 parent ce14e18 commit d83f031

4 files changed

Lines changed: 38 additions & 20 deletions

File tree

tools/generate-vbinary-test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
import re
1212
import sys
1313
import yaml
14+
try:
15+
from immutabledict import immutabledict
16+
except ImportError:
17+
immutabledict = dict
1418

1519
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
1620
import xngen
@@ -44,7 +48,7 @@
4448
)
4549
parser.set_defaults(defines=list())
4650

47-
OP_TYPES = {
51+
OP_TYPES = immutabledict({
4852
"vadd": "Add",
4953
"vaddc": "Add",
5054
"vcopysign": "CopySign",
@@ -71,7 +75,7 @@
7175
"vprelu": "Prelu",
7276
"vpreluc": "Prelu",
7377
"vrpreluc": "RPrelu",
74-
}
78+
})
7579

7680
BINOP_TEST_TEMPLATE = """
7781
#define XNN_UKERNEL(arch_flags, ukernel, batch_tile, vector_tile, datatype, params_type, init_params)

tools/generate-vunary-test.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88
import math
99
import os
1010
import sys
11-
import types
1211
from typing import NamedTuple
1312

13+
try:
14+
from immutabledict import immutabledict
15+
except ImportError:
16+
immutabledict = dict
17+
1418
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
1519
import xngen
1620
import xnncommon
@@ -41,7 +45,7 @@ class SpecialValues(NamedTuple):
4145
expected_outputs: str
4246
tolerance_ulp: int
4347

44-
OP_TYPES = {
48+
OP_TYPES = immutabledict({
4549
"vabs": "Abs",
4650
"vapproxgelu": "ApproxGELU",
4751
"vclamp": "Clamp",
@@ -63,11 +67,11 @@ class SpecialValues(NamedTuple):
6367
"vsqr": "Square",
6468
"vsqrt": "SquareRoot",
6569
"vtanh": "TanH",
66-
}
70+
})
6771

6872
PARAMS_TYPES = ["Clamp", "ELU", "LeakyReLU"]
6973

70-
SPECIAL_VALUES_BY_OP_TYPE_F32 = types.MappingProxyType({
74+
SPECIAL_VALUES_BY_OP_TYPE_F32 = immutabledict({
7175
"SquareRoot": SpecialValues(
7276
num_elements=4,
7377
inputs="{0.0f, -0.0f, 1.0f, -1.0f}",
@@ -127,7 +131,7 @@ class SpecialValues(NamedTuple):
127131
})
128132

129133

130-
SPECIAL_VALUES_BY_OP_TYPE_F16 = types.MappingProxyType({
134+
SPECIAL_VALUES_BY_OP_TYPE_F16 = immutabledict({
131135
"Log": SpecialValues(
132136
num_elements=4,
133137
inputs="{1.0f, -1.0f, 0.0f, -0.0f}",

tools/update-microkernels.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
import os
1111
import re
1212
import sys
13+
try:
14+
from immutabledict import immutabledict
15+
except ImportError:
16+
immutabledict = dict
17+
1318

1419
parser = argparse.ArgumentParser(
1520
description='Utility for re-generating microkernel lists'
@@ -81,12 +86,12 @@
8186
'wasmsimd',
8287
})
8388

84-
_ISA_MAP = {
89+
_ISA_MAP = immutabledict({
8590
'wasmblendvps': 'wasmrelaxedsimd',
8691
'wasmpshufb': 'wasmrelaxedsimd',
8792
'wasmsdot': 'wasmrelaxedsimd',
8893
'wasmusdot': 'wasmrelaxedsimd',
89-
}
94+
})
9095

9196
_ARCH_LIST = frozenset({
9297
'aarch32',
@@ -101,12 +106,12 @@
101106
r'\bxnn_(?:[a-z0-9]+(?:_[a-z0-9]+)*)_ukernel(?:_[a-z0-9]+)*__(?:[a-z0-9]+(?:_[a-z0-9]+)*)\b'
102107
)
103108

104-
_VERIFICATION_IGNORE_SUBDIRS = {
109+
_VERIFICATION_IGNORE_SUBDIRS = frozenset({
105110
os.path.join('src', 'qs8-requantization'),
106111
os.path.join('src', 'qu8-requantization'),
107112
os.path.join('src', 'reference'),
108113
os.path.join('src', 'xnnpack', 'simd'),
109-
}
114+
})
110115

111116

112117
def overwrite_if_changed(filepath, content):

tools/xnncommon.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66

77
import codecs
88
import os
9+
try:
10+
from immutabledict import immutabledict
11+
except ImportError:
12+
immutabledict = dict
13+
914

1015

1116
def _indent(text):
@@ -23,7 +28,7 @@ def _remove_duplicate_newlines(text):
2328
return "\n".join(filtered_lines)
2429

2530

26-
_ARCH_TO_MACRO_MAP = {
31+
_ARCH_TO_MACRO_MAP = immutabledict({
2732
"aarch32": "XNN_ARCH_ARM",
2833
"aarch64": "XNN_ARCH_ARM64",
2934
"x86-32": "XNN_ARCH_X86",
@@ -38,11 +43,11 @@ def _remove_duplicate_newlines(text):
3843
"wasmsimd32": "XNN_ARCH_WASMSIMD",
3944
"wasmrelaxedsimd32": "XNN_ARCH_WASMRELAXEDSIMD",
4045
"wasmrelaxedsimdfp16": "XNN_ARCH_WASMRELAXEDSIMDFP16",
41-
}
46+
})
4247

4348
# Mapping from ISA extension to macro guarding build-time enabled/disabled
4449
# status for the ISA. Only ISAs that can be enabled/disabled have an entry.
45-
_ISA_TO_MACRO_MAP = {
50+
_ISA_TO_MACRO_MAP = immutabledict({
4651
"fp16arith": "XNN_ENABLE_ARM_FP16_SCALAR",
4752
"neonfp16arith": "XNN_ENABLE_ARM_FP16_VECTOR",
4853
"neonbf16": "XNN_ENABLE_ARM_BF16",
@@ -75,9 +80,9 @@ def _remove_duplicate_newlines(text):
7580
"avx512fp16": "XNN_ENABLE_AVX512FP16",
7681
"avx512bf16": "XNN_ENABLE_AVX512BF16",
7782
"hvx": "XNN_ENABLE_HVX",
78-
}
83+
})
7984

80-
_ISA_TO_ARCH_MAP = {
85+
_ISA_TO_ARCH_MAP = immutabledict({
8186
"armsimd32": ["aarch32"],
8287
"fp16arith": ["aarch32", "aarch64"],
8388
"neon": ["aarch32", "aarch64"],
@@ -124,9 +129,9 @@ def _remove_duplicate_newlines(text):
124129
"wasmsdot": ["wasmrelaxedsimd"],
125130
"wasmusdot": ["wasmrelaxedsimd"],
126131
"wasmblendvps": ["wasmrelaxedsimd"],
127-
}
132+
})
128133

129-
_ISA_TO_ARCH_FLAGS_MAP = {
134+
_ISA_TO_ARCH_FLAGS_MAP = immutabledict({
130135
"armsimd32": "xnn_arch_arm_v6",
131136
"fp16arith": "xnn_arch_arm_fp16_arith",
132137
"neon": "xnn_arch_arm_neon",
@@ -168,7 +173,7 @@ def _remove_duplicate_newlines(text):
168173
"wasmsdot": "xnn_arch_wasm_sdot",
169174
"wasmusdot": "xnn_arch_wasm_usdot",
170175
"wasmblendvps": "xnn_arch_wasm_blendvps",
171-
}
176+
})
172177

173178

174179
def isa_hierarchy_map():
@@ -268,7 +273,7 @@ def postprocess_test_case(test_case, arch, isa, assembly=False):
268273
"hvx",
269274
]
270275

271-
_ISA_HIERARCHY_MAP = {isa: v for v, isa in enumerate(_ISA_HIERARCHY)}
276+
_ISA_HIERARCHY_MAP = immutabledict({isa: v for v, isa in enumerate(_ISA_HIERARCHY)})
272277

273278

274279
def overwrite_if_changed(filepath, content):

0 commit comments

Comments
 (0)