Skip to content

Commit 22dd3d9

Browse files
committed
add a shim for new and old MKL constants
allows mkl-service to work with both old and new versions
1 parent 1e45475 commit 22dd3d9

4 files changed

Lines changed: 198 additions & 12 deletions

File tree

conda-recipe/meta.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ requirements:
2222
- python-gil # [py>=314]
2323
- pip >=25.0
2424
- setuptools >=77
25-
- mkl-devel
25+
- mkl-devel <2024 # [osx]
26+
- mkl-devel # [not osx]
2627
- cython
2728
- wheel >=0.45.1
2829
- python-build >=1.2.2

mkl/_mkl_service.pxd

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ cdef extern from "mkl.h":
5959
int MKL_CBWR_AVX2
6060
int MKL_CBWR_AVX512
6161
int MKL_CBWR_AVX512_E1
62-
int MKL_CBWR_AVX10
6362

6463
int MKL_CBWR_SUCCESS
6564
int MKL_CBWR_ERR_INVALID_SETTINGS
@@ -74,12 +73,10 @@ cdef extern from "mkl.h":
7473
int MKL_ENABLE_AVX512_E3
7574
int MKL_ENABLE_AVX512_E4
7675
int MKL_ENABLE_AVX512_E1
77-
int MKL_ENABLE_AVX512_E5
7876
int MKL_ENABLE_AVX512
7977
int MKL_ENABLE_AVX2
8078
int MKL_ENABLE_AVX2_E1
8179
int MKL_ENABLE_SSE4_2
82-
int MKL_ENABLE_AVX10
8380

8481
# MPI Implementation Constants
8582
int MKL_BLACS_CUSTOM
@@ -169,3 +166,62 @@ cdef extern from "mkl.h":
169166
int vmlSetErrStatus(const MKL_INT status)
170167
int vmlGetErrStatus()
171168
int vmlClearErrStatus()
169+
170+
# version-compat shim
171+
cdef extern from *:
172+
"""
173+
#include <mkl.h>
174+
175+
/* define constants removed in 2026.0 if undefined */
176+
#ifndef MKL_CBWR_SSSE3
177+
#define MKL_CBWR_SSSE3 -1
178+
#endif
179+
#ifndef MKL_CBWR_SSE4_1
180+
#define MKL_CBWR_SSE4_1 -1
181+
#endif
182+
#ifndef MKL_CBWR_AVX
183+
#define MKL_CBWR_AVX -1
184+
#endif
185+
#ifndef MKL_CBWR_AVX512_MIC
186+
#define MKL_CBWR_AVX512_MIC -1
187+
#endif
188+
#ifndef MKL_CBWR_AVX512_MIC_E1
189+
#define MKL_CBWR_AVX512_MIC_E1 -1
190+
#endif
191+
192+
#ifndef MKL_ENABLE_AVX512_MIC_E1
193+
#define MKL_ENABLE_AVX512_MIC_E1 -1
194+
#endif
195+
#ifndef MKL_ENABLE_AVX512_MIC
196+
#define MKL_ENABLE_AVX512_MIC -1
197+
#endif
198+
#ifndef MKL_ENABLE_AVX
199+
#define MKL_ENABLE_AVX -1
200+
#endif
201+
202+
/* define constants from 2026.0 if undefined */
203+
#ifndef MKL_CBWR_AVX10
204+
#define MKL_CBWR_AVX10 -2
205+
#endif
206+
207+
#ifndef MKL_ENABLE_AVX512_E5
208+
#define MKL_ENABLE_AVX512_E5 -2
209+
#endif
210+
#ifndef MKL_ENABLE_AVX10
211+
#define MKL_ENABLE_AVX10 -2
212+
#endif
213+
"""
214+
int MKL_CBWR_SSSE3
215+
int MKL_CBWR_SSE4_1
216+
int MKL_CBWR_AVX
217+
int MKL_CBWR_AVX512_MIC
218+
int MKL_CBWR_AVX512_MIC_E1
219+
220+
int MKL_ENABLE_AVX512_MIC_E1
221+
int MKL_ENABLE_AVX512_MIC
222+
int MKL_ENABLE_AVX
223+
224+
int MKL_CBWR_AVX10
225+
226+
int MKL_ENABLE_AVX512_E5
227+
int MKL_ENABLE_AVX10

mkl/_mkl_service.pyx

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,23 @@ cdef object __cbwr_set(branch=None):
694694
mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE: "err_mode_change_failure",
695695
},
696696
}
697+
# legacy branches removed in 2026.0
698+
if mkl.MKL_CBWR_SSSE3 != -1:
699+
__variables["input"]["ssse3"] = mkl.MKL_CBWR_SSSE3
700+
if mkl.MKL_CBWR_SSE4_1 != -1:
701+
__variables["input"]["sse4_1"] = mkl.MKL_CBWR_SSE4_1
702+
if mkl.MKL_CBWR_AVX != -1:
703+
__variables["input"]["avx"] = mkl.MKL_CBWR_AVX
704+
if mkl.MKL_CBWR_AVX512_MIC != -1:
705+
__variables["input"]["avx512_mic"] = mkl.MKL_CBWR_AVX512_MIC
706+
__variables["input"][
707+
"avx512_mic,strict"
708+
] = mkl.MKL_CBWR_AVX512_MIC | mkl.MKL_CBWR_STRICT
709+
if mkl.MKL_CBWR_AVX512_MIC_E1 != -1:
710+
__variables["input"]["avx512_mic_e1"] = mkl.MKL_CBWR_AVX512_MIC_E1
711+
__variables["input"][
712+
"avx512_mic_e1,strict"
713+
] = mkl.MKL_CBWR_AVX512_MIC_E1 | mkl.MKL_CBWR_STRICT
697714
mkl_branch = __mkl_str_to_int(branch, __variables["input"])
698715

699716
mkl_status = mkl.mkl_cbwr_set(mkl_branch)
@@ -728,6 +745,23 @@ cdef inline __cbwr_get(cnr_const=None):
728745
mkl.MKL_CBWR_ERR_INVALID_INPUT: "err_invalid_input",
729746
},
730747
}
748+
# legacy branches removed in 2026.0
749+
if mkl.MKL_CBWR_SSSE3 != -1:
750+
__variables["output"][mkl.MKL_CBWR_SSSE3] = "ssse3"
751+
if mkl.MKL_CBWR_SSE4_1 != -1:
752+
__variables["output"][mkl.MKL_CBWR_SSE4_1] = "sse4_1"
753+
if mkl.MKL_CBWR_AVX != -1:
754+
__variables["output"][mkl.MKL_CBWR_AVX] = "avx"
755+
if mkl.MKL_CBWR_AVX512_MIC != -1:
756+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC] = "avx512_mic"
757+
__variables["output"][
758+
mkl.MKL_CBWR_AVX512_MIC | mkl.MKL_CBWR_STRICT
759+
] = "avx512_mic,strict"
760+
if mkl.MKL_CBWR_AVX512_MIC_E1 != -1:
761+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC_E1] = "avx512_mic_e1"
762+
__variables["output"][
763+
mkl.MKL_CBWR_AVX512_MIC_E1 | mkl.MKL_CBWR_STRICT
764+
] = "avx512_mic_e1,strict"
731765
mkl_cnr_const = __mkl_str_to_int(cnr_const, __variables["input"])
732766

733767
mkl_status = mkl.mkl_cbwr_get(mkl_cnr_const)
@@ -753,12 +787,32 @@ cdef object __cbwr_get_auto_branch():
753787
mkl.MKL_CBWR_AVX512 | mkl.MKL_CBWR_STRICT: "avx512,strict",
754788
mkl.MKL_CBWR_AVX512_E1: "avx512_e1",
755789
mkl.MKL_CBWR_AVX512_E1 | mkl.MKL_CBWR_STRICT: "avx512_e1,strict",
756-
mkl.MKL_CBWR_AVX10: "avx10",
757-
mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT: "avx10,strict",
758790
mkl.MKL_CBWR_SUCCESS: "success",
759791
mkl.MKL_CBWR_ERR_INVALID_INPUT: "err_invalid_input",
760792
},
761793
}
794+
# new CNR branches added in 2026.0
795+
if mkl.MKL_CBWR_AVX10 != -2:
796+
__variables["output"][mkl.MKL_CBWR_AVX10] = "avx10"
797+
__variables["output"][mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT] = "avx10,strict"
798+
799+
# legacy CNR branches removed in 2026.0
800+
if mkl.MKL_CBWR_SSSE3 != -1:
801+
__variables["output"][mkl.MKL_CBWR_SSSE3] = "ssse3"
802+
if mkl.MKL_CBWR_SSE4_1 != -1:
803+
__variables["output"][mkl.MKL_CBWR_SSE4_1] = "sse4_1"
804+
if mkl.MKL_CBWR_AVX != -1:
805+
__variables["output"][mkl.MKL_CBWR_AVX] = "avx"
806+
if mkl.MKL_CBWR_AVX512_MIC != -1:
807+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC] = "avx512_mic"
808+
__variables["output"][
809+
mkl.MKL_CBWR_AVX512_MIC | mkl.MKL_CBWR_STRICT
810+
] = "avx512_mic,strict"
811+
if mkl.MKL_CBWR_AVX512_MIC_E1 != -1:
812+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC_E1] = "avx512_mic_e1"
813+
__variables["output"][
814+
mkl.MKL_CBWR_AVX512_MIC_E1 | mkl.MKL_CBWR_STRICT
815+
] = "avx512_mic_e1,strict"
762816

763817
mkl_status = mkl.mkl_cbwr_get_auto_branch()
764818

@@ -779,14 +833,24 @@ cdef object __enable_instructions(isa=None):
779833
"avx512_e3": mkl.MKL_ENABLE_AVX512_E3,
780834
"avx512_e2": mkl.MKL_ENABLE_AVX512_E2,
781835
"avx512_e1": mkl.MKL_ENABLE_AVX512_E1,
782-
"avx512_e5": mkl.MKL_ENABLE_AVX512_E5,
783836
"avx512": mkl.MKL_ENABLE_AVX512,
784837
"avx2_e1": mkl.MKL_ENABLE_AVX2_E1,
785838
"avx2": mkl.MKL_ENABLE_AVX2,
786839
"sse4_2": mkl.MKL_ENABLE_SSE4_2,
787-
"avx10": mkl.MKL_ENABLE_AVX10,
788840
},
789841
}
842+
# new constants added in 2026.0
843+
if mkl.MKL_ENABLE_AVX512_E5 != -2:
844+
__variables["input"]["avx512_e5"] = mkl.MKL_ENABLE_AVX512_E5
845+
if mkl.MKL_ENABLE_AVX10 != -2:
846+
__variables["input"]["avx10"] = mkl.MKL_ENABLE_AVX10
847+
# legacy constants removed in 2026.0
848+
if mkl.MKL_ENABLE_AVX != -1:
849+
__variables["input"]["avx"] = mkl.MKL_ENABLE_AVX
850+
if mkl.MKL_ENABLE_AVX512_MIC != -1:
851+
__variables["input"]["avx512_mic"] = mkl.MKL_ENABLE_AVX512_MIC
852+
if mkl.MKL_ENABLE_AVX512_MIC_E1 != -1:
853+
__variables["input"]["avx512_mic_e1"] = mkl.MKL_ENABLE_AVX512_MIC_E1
790854
cdef int c_mkl_isa = __mkl_str_to_int(isa, __variables["input"])
791855

792856
cdef int c_mkl_status = mkl.mkl_enable_instructions(c_mkl_isa)

mkl/tests/test_mkl_service.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ def check_cbwr(branch, cnr_const):
213213
pytest.fail(status)
214214

215215

216+
_mkl_major = mkl.get_version()["MajorVersion"]
217+
216218
branches = [
217219
"off",
218220
"branch_off",
@@ -223,15 +225,20 @@ def check_cbwr(branch, cnr_const):
223225
"avx2",
224226
"avx512",
225227
"avx512_e1",
226-
"avx10",
227228
]
228229

230+
# removed in MKL 2026.0
231+
legacy_cbwr_branches = ["ssse3", "sse4_1", "avx", "avx512_mic", "avx512_mic_e1"]
232+
legacy_cbwr_strict = ["avx512_mic,strict", "avx512_mic_e1,strict"]
233+
234+
# added in MKL 2026.0
235+
new_cbwr_branches = ["avx10"]
236+
new_cbwr_strict = ["avx10,strict"]
229237

230238
strict = [
231239
"avx2,strict",
232240
"avx512,strict",
233241
"avx512_e1,strict",
234-
"avx10,strict",
235242
]
236243

237244

@@ -249,26 +256,84 @@ def test_cbwr_get_auto_branch():
249256
mkl.cbwr_get_auto_branch()
250257

251258

259+
@pytest.mark.skipif(
260+
_mkl_major >= 2026,
261+
reason="Removed in MKL 2026.0",
262+
)
263+
@pytest.mark.parametrize("branch", legacy_cbwr_branches)
264+
def test_cbwr_branch_legacy(branch):
265+
check_cbwr(branch, "branch")
266+
267+
268+
@pytest.mark.skipif(
269+
_mkl_major >= 2026,
270+
reason="Removed in MKL 2026.0",
271+
)
272+
@pytest.mark.parametrize("branch", legacy_cbwr_branches + legacy_cbwr_strict)
273+
def test_cbwr_legacy(branch):
274+
check_cbwr(branch, "all")
275+
276+
277+
@pytest.mark.skipif(
278+
_mkl_major < 2026,
279+
reason="Added in MKL 2026.0",
280+
)
281+
@pytest.mark.parametrize("branch", new_cbwr_branches)
282+
def test_cbwr_branch_new(branch):
283+
check_cbwr(branch, "branch")
284+
285+
286+
@pytest.mark.skipif(
287+
_mkl_major < 2026,
288+
reason="Added in MKL 2026.0",
289+
)
290+
@pytest.mark.parametrize("branch", new_cbwr_branches + new_cbwr_strict)
291+
def test_cbwr_all_new(branch):
292+
check_cbwr(branch, "all")
293+
294+
252295
instructions = [
253296
"single_path",
254297
"avx512_e4",
255298
"avx512_e3",
256299
"avx512_e2",
257300
"avx512_e1",
258-
"avx512_e5",
259301
"avx512",
260302
"avx2_e1",
261303
"avx2",
262304
"sse4_2",
263-
"avx10",
264305
]
265306

307+
# removed in MKL 2026.0
308+
legacy_instructions = ["avx", "avx512_mic", "avx512_mic_e1"]
309+
310+
# added in MKL 2026.0
311+
new_instructions = ["avx512_e5", "avx10"]
312+
266313

267314
@pytest.mark.parametrize("isa", instructions)
268315
def test_enable_instructions(isa):
269316
mkl.enable_instructions(isa)
270317

271318

319+
@pytest.mark.skipif(
320+
_mkl_major >= 2026,
321+
reason="Removed in MKL 2026.0",
322+
)
323+
@pytest.mark.parametrize("isa", legacy_instructions)
324+
def test_enable_instructions_legacy(isa):
325+
mkl.enable_instructions(isa)
326+
327+
328+
@pytest.mark.skipif(
329+
_mkl_major < 2026,
330+
reason="Added in MKL 2026.0",
331+
)
332+
@pytest.mark.parametrize("isa", new_instructions)
333+
def test_enable_instructions_new(isa):
334+
mkl.enable_instructions(isa)
335+
336+
272337
def test_set_env_mode():
273338
mkl.set_env_mode()
274339

0 commit comments

Comments
 (0)