Skip to content

Commit cd985ac

Browse files
committed
add a shim for new and old MKL constants
allows mkl-service to work with both old and new versions
1 parent 1e45475 commit cd985ac

4 files changed

Lines changed: 205 additions & 16 deletions

File tree

conda-recipe/meta.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ requirements:
2222
- python-gil # [py>=314]
2323
- pip >=25.0
2424
- setuptools >=77
25-
- mkl-devel
25+
- mkl-devel <2024 # [osx]
26+
- mkl-devel # [not osx]
2627
- cython
2728
- wheel >=0.45.1
2829
- python-build >=1.2.2

mkl/_mkl_service.pxd

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ cdef extern from "mkl.h":
5959
int MKL_CBWR_AVX2
6060
int MKL_CBWR_AVX512
6161
int MKL_CBWR_AVX512_E1
62-
int MKL_CBWR_AVX10
6362

6463
int MKL_CBWR_SUCCESS
6564
int MKL_CBWR_ERR_INVALID_SETTINGS
@@ -74,12 +73,10 @@ cdef extern from "mkl.h":
7473
int MKL_ENABLE_AVX512_E3
7574
int MKL_ENABLE_AVX512_E4
7675
int MKL_ENABLE_AVX512_E1
77-
int MKL_ENABLE_AVX512_E5
7876
int MKL_ENABLE_AVX512
7977
int MKL_ENABLE_AVX2
8078
int MKL_ENABLE_AVX2_E1
8179
int MKL_ENABLE_SSE4_2
82-
int MKL_ENABLE_AVX10
8380

8481
# MPI Implementation Constants
8582
int MKL_BLACS_CUSTOM
@@ -169,3 +166,62 @@ cdef extern from "mkl.h":
169166
int vmlSetErrStatus(const MKL_INT status)
170167
int vmlGetErrStatus()
171168
int vmlClearErrStatus()
169+
170+
# version-compat shim
171+
cdef extern from *:
172+
"""
173+
#include <mkl.h>
174+
175+
/* define constants removed in 2026.0 if undefined */
176+
#ifndef MKL_CBWR_SSSE3
177+
#define MKL_CBWR_SSSE3 -1
178+
#endif
179+
#ifndef MKL_CBWR_SSE4_1
180+
#define MKL_CBWR_SSE4_1 -1
181+
#endif
182+
#ifndef MKL_CBWR_AVX
183+
#define MKL_CBWR_AVX -1
184+
#endif
185+
#ifndef MKL_CBWR_AVX512_MIC
186+
#define MKL_CBWR_AVX512_MIC -1
187+
#endif
188+
#ifndef MKL_CBWR_AVX512_MIC_E1
189+
#define MKL_CBWR_AVX512_MIC_E1 -1
190+
#endif
191+
192+
#ifndef MKL_ENABLE_AVX512_MIC_E1
193+
#define MKL_ENABLE_AVX512_MIC_E1 -1
194+
#endif
195+
#ifndef MKL_ENABLE_AVX512_MIC
196+
#define MKL_ENABLE_AVX512_MIC -1
197+
#endif
198+
#ifndef MKL_ENABLE_AVX
199+
#define MKL_ENABLE_AVX -1
200+
#endif
201+
202+
/* define constants from 2026.0 if undefined */
203+
#ifndef MKL_CBWR_AVX10
204+
#define MKL_CBWR_AVX10 -2
205+
#endif
206+
207+
#ifndef MKL_ENABLE_AVX512_E5
208+
#define MKL_ENABLE_AVX512_E5 -2
209+
#endif
210+
#ifndef MKL_ENABLE_AVX10
211+
#define MKL_ENABLE_AVX10 -2
212+
#endif
213+
"""
214+
int MKL_CBWR_SSSE3
215+
int MKL_CBWR_SSE4_1
216+
int MKL_CBWR_AVX
217+
int MKL_CBWR_AVX512_MIC
218+
int MKL_CBWR_AVX512_MIC_E1
219+
220+
int MKL_ENABLE_AVX512_MIC_E1
221+
int MKL_ENABLE_AVX512_MIC
222+
int MKL_ENABLE_AVX
223+
224+
int MKL_CBWR_AVX10
225+
226+
int MKL_ENABLE_AVX512_E5
227+
int MKL_ENABLE_AVX10

mkl/_mkl_service.pyx

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -684,8 +684,6 @@ cdef object __cbwr_set(branch=None):
684684
"avx512,strict": mkl.MKL_CBWR_AVX512 | mkl.MKL_CBWR_STRICT,
685685
"avx512_e1": mkl.MKL_CBWR_AVX512_E1,
686686
"avx512_e1,strict": mkl.MKL_CBWR_AVX512_E1 | mkl.MKL_CBWR_STRICT,
687-
"avx10": mkl.MKL_CBWR_AVX10,
688-
"avx10,strict": mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT,
689687
},
690688
"output": {
691689
mkl.MKL_CBWR_SUCCESS: "success",
@@ -694,6 +692,27 @@ cdef object __cbwr_set(branch=None):
694692
mkl.MKL_CBWR_ERR_MODE_CHANGE_FAILURE: "err_mode_change_failure",
695693
},
696694
}
695+
# new CNR branches added in 2026.0
# NOTE: "input" maps branch *name* -> MKL constant (it is what
# __mkl_str_to_int looks the user-supplied string up in), so the key must
# be the string and the value the constant -- the reverse of "output".
696+
if mkl.MKL_CBWR_AVX10 != -2:
697+
__variables["input"]["avx10"] = mkl.MKL_CBWR_AVX10
698+
__variables["input"]["avx10,strict"] = mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT
699+
# legacy branches removed in 2026.0
700+
if mkl.MKL_CBWR_SSSE3 != -1:
701+
__variables["input"]["ssse3"] = mkl.MKL_CBWR_SSSE3
702+
if mkl.MKL_CBWR_SSE4_1 != -1:
703+
__variables["input"]["sse4_1"] = mkl.MKL_CBWR_SSE4_1
704+
if mkl.MKL_CBWR_AVX != -1:
705+
__variables["input"]["avx"] = mkl.MKL_CBWR_AVX
706+
if mkl.MKL_CBWR_AVX512_MIC != -1:
707+
__variables["input"]["avx512_mic"] = mkl.MKL_CBWR_AVX512_MIC
708+
__variables["input"][
709+
"avx512_mic,strict"
710+
] = mkl.MKL_CBWR_AVX512_MIC | mkl.MKL_CBWR_STRICT
711+
if mkl.MKL_CBWR_AVX512_MIC_E1 != -1:
712+
__variables["input"]["avx512_mic_e1"] = mkl.MKL_CBWR_AVX512_MIC_E1
713+
__variables["input"][
714+
"avx512_mic_e1,strict"
715+
] = mkl.MKL_CBWR_AVX512_MIC_E1 | mkl.MKL_CBWR_STRICT
697716
mkl_branch = __mkl_str_to_int(branch, __variables["input"])
698717

699718
mkl_status = mkl.mkl_cbwr_set(mkl_branch)
@@ -723,11 +742,30 @@ cdef inline __cbwr_get(cnr_const=None):
723742
mkl.MKL_CBWR_AVX512 | mkl.MKL_CBWR_STRICT: "avx512,strict",
724743
mkl.MKL_CBWR_AVX512_E1: "avx512_e1",
725744
mkl.MKL_CBWR_AVX512_E1 | mkl.MKL_CBWR_STRICT: "avx512_e1,strict",
726-
mkl.MKL_CBWR_AVX10: "avx10",
727-
mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT: "avx10,strict",
728745
mkl.MKL_CBWR_ERR_INVALID_INPUT: "err_invalid_input",
729746
},
730747
}
748+
# new CNR branches added in 2026.0
749+
if mkl.MKL_CBWR_AVX10 != -2:
750+
__variables["output"][mkl.MKL_CBWR_AVX10] = "avx10"
751+
__variables["output"][mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT] = "avx10,strict"
752+
# legacy branches removed in 2026.0
753+
if mkl.MKL_CBWR_SSSE3 != -1:
754+
__variables["output"][mkl.MKL_CBWR_SSSE3] = "ssse3"
755+
if mkl.MKL_CBWR_SSE4_1 != -1:
756+
__variables["output"][mkl.MKL_CBWR_SSE4_1] = "sse4_1"
757+
if mkl.MKL_CBWR_AVX != -1:
758+
__variables["output"][mkl.MKL_CBWR_AVX] = "avx"
759+
if mkl.MKL_CBWR_AVX512_MIC != -1:
760+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC] = "avx512_mic"
761+
__variables["output"][
762+
mkl.MKL_CBWR_AVX512_MIC | mkl.MKL_CBWR_STRICT
763+
] = "avx512_mic,strict"
764+
if mkl.MKL_CBWR_AVX512_MIC_E1 != -1:
765+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC_E1] = "avx512_mic_e1"
766+
__variables["output"][
767+
mkl.MKL_CBWR_AVX512_MIC_E1 | mkl.MKL_CBWR_STRICT
768+
] = "avx512_mic_e1,strict"
731769
mkl_cnr_const = __mkl_str_to_int(cnr_const, __variables["input"])
732770

733771
mkl_status = mkl.mkl_cbwr_get(mkl_cnr_const)
@@ -753,12 +791,31 @@ cdef object __cbwr_get_auto_branch():
753791
mkl.MKL_CBWR_AVX512 | mkl.MKL_CBWR_STRICT: "avx512,strict",
754792
mkl.MKL_CBWR_AVX512_E1: "avx512_e1",
755793
mkl.MKL_CBWR_AVX512_E1 | mkl.MKL_CBWR_STRICT: "avx512_e1,strict",
756-
mkl.MKL_CBWR_AVX10: "avx10",
757-
mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT: "avx10,strict",
758794
mkl.MKL_CBWR_SUCCESS: "success",
759795
mkl.MKL_CBWR_ERR_INVALID_INPUT: "err_invalid_input",
760796
},
761797
}
798+
# new CNR branches added in 2026.0
799+
if mkl.MKL_CBWR_AVX10 != -2:
800+
__variables["output"][mkl.MKL_CBWR_AVX10] = "avx10"
801+
__variables["output"][mkl.MKL_CBWR_AVX10 | mkl.MKL_CBWR_STRICT] = "avx10,strict"
802+
# legacy CNR branches removed in 2026.0
803+
if mkl.MKL_CBWR_SSSE3 != -1:
804+
__variables["output"][mkl.MKL_CBWR_SSSE3] = "ssse3"
805+
if mkl.MKL_CBWR_SSE4_1 != -1:
806+
__variables["output"][mkl.MKL_CBWR_SSE4_1] = "sse4_1"
807+
if mkl.MKL_CBWR_AVX != -1:
808+
__variables["output"][mkl.MKL_CBWR_AVX] = "avx"
809+
if mkl.MKL_CBWR_AVX512_MIC != -1:
810+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC] = "avx512_mic"
811+
__variables["output"][
812+
mkl.MKL_CBWR_AVX512_MIC | mkl.MKL_CBWR_STRICT
813+
] = "avx512_mic,strict"
814+
if mkl.MKL_CBWR_AVX512_MIC_E1 != -1:
815+
__variables["output"][mkl.MKL_CBWR_AVX512_MIC_E1] = "avx512_mic_e1"
816+
__variables["output"][
817+
mkl.MKL_CBWR_AVX512_MIC_E1 | mkl.MKL_CBWR_STRICT
818+
] = "avx512_mic_e1,strict"
762819

763820
mkl_status = mkl.mkl_cbwr_get_auto_branch()
764821

@@ -779,14 +836,24 @@ cdef object __enable_instructions(isa=None):
779836
"avx512_e3": mkl.MKL_ENABLE_AVX512_E3,
780837
"avx512_e2": mkl.MKL_ENABLE_AVX512_E2,
781838
"avx512_e1": mkl.MKL_ENABLE_AVX512_E1,
782-
"avx512_e5": mkl.MKL_ENABLE_AVX512_E5,
783839
"avx512": mkl.MKL_ENABLE_AVX512,
784840
"avx2_e1": mkl.MKL_ENABLE_AVX2_E1,
785841
"avx2": mkl.MKL_ENABLE_AVX2,
786842
"sse4_2": mkl.MKL_ENABLE_SSE4_2,
787-
"avx10": mkl.MKL_ENABLE_AVX10,
788843
},
789844
}
845+
# new constants added in 2026.0
846+
if mkl.MKL_ENABLE_AVX512_E5 != -2:
847+
__variables["input"]["avx512_e5"] = mkl.MKL_ENABLE_AVX512_E5
848+
if mkl.MKL_ENABLE_AVX10 != -2:
849+
__variables["input"]["avx10"] = mkl.MKL_ENABLE_AVX10
850+
# legacy constants removed in 2026.0
851+
if mkl.MKL_ENABLE_AVX != -1:
852+
__variables["input"]["avx"] = mkl.MKL_ENABLE_AVX
853+
if mkl.MKL_ENABLE_AVX512_MIC != -1:
854+
__variables["input"]["avx512_mic"] = mkl.MKL_ENABLE_AVX512_MIC
855+
if mkl.MKL_ENABLE_AVX512_MIC_E1 != -1:
856+
__variables["input"]["avx512_mic_e1"] = mkl.MKL_ENABLE_AVX512_MIC_E1
790857
cdef int c_mkl_isa = __mkl_str_to_int(isa, __variables["input"])
791858

792859
cdef int c_mkl_status = mkl.mkl_enable_instructions(c_mkl_isa)

mkl/tests/test_mkl_service.py

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ def check_cbwr(branch, cnr_const):
213213
pytest.fail(status)
214214

215215

216+
_mkl_major = mkl.get_version()["MajorVersion"]
217+
216218
branches = [
217219
"off",
218220
"branch_off",
@@ -223,15 +225,20 @@ def check_cbwr(branch, cnr_const):
223225
"avx2",
224226
"avx512",
225227
"avx512_e1",
226-
"avx10",
227228
]
228229

230+
# removed in MKL 2026.0
231+
legacy_cbwr_branches = ["ssse3", "sse4_1", "avx", "avx512_mic", "avx512_mic_e1"]
232+
legacy_cbwr_strict = ["avx512_mic,strict", "avx512_mic_e1,strict"]
233+
234+
# added in MKL 2026.0
235+
new_cbwr_branches = ["avx10"]
236+
new_cbwr_strict = ["avx10,strict"]
229237

230238
strict = [
231239
"avx2,strict",
232240
"avx512,strict",
233241
"avx512_e1,strict",
234-
"avx10,strict",
235242
]
236243

237244

@@ -249,26 +256,84 @@ def test_cbwr_get_auto_branch():
249256
mkl.cbwr_get_auto_branch()
250257

251258

259+
@pytest.mark.skipif(
260+
_mkl_major >= 2026,
261+
reason="Removed in MKL 2026.0",
262+
)
263+
@pytest.mark.parametrize("branch", legacy_cbwr_branches)
264+
def test_cbwr_branch_legacy(branch):
265+
check_cbwr(branch, "branch")
266+
267+
268+
@pytest.mark.skipif(
269+
_mkl_major >= 2026,
270+
reason="Removed in MKL 2026.0",
271+
)
272+
@pytest.mark.parametrize("branch", legacy_cbwr_branches + legacy_cbwr_strict)
273+
def test_cbwr_legacy(branch):
274+
check_cbwr(branch, "all")
275+
276+
277+
@pytest.mark.skipif(
278+
_mkl_major < 2026,
279+
reason="Added in MKL 2026.0",
280+
)
281+
@pytest.mark.parametrize("branch", new_cbwr_branches)
282+
def test_cbwr_branch_new(branch):
283+
check_cbwr(branch, "branch")
284+
285+
286+
@pytest.mark.skipif(
287+
_mkl_major < 2026,
288+
reason="Added in MKL 2026.0",
289+
)
290+
@pytest.mark.parametrize("branch", new_cbwr_branches + new_cbwr_strict)
291+
def test_cbwr_all_new(branch):
292+
check_cbwr(branch, "all")
293+
294+
252295
instructions = [
253296
"single_path",
254297
"avx512_e4",
255298
"avx512_e3",
256299
"avx512_e2",
257300
"avx512_e1",
258-
"avx512_e5",
259301
"avx512",
260302
"avx2_e1",
261303
"avx2",
262304
"sse4_2",
263-
"avx10",
264305
]
265306

307+
# removed in MKL 2026.0
308+
legacy_instructions = ["avx", "avx512_mic", "avx512_mic_e1"]
309+
310+
# added in MKL 2026.0
311+
new_instructions = ["avx512_e5", "avx10"]
312+
266313

267314
@pytest.mark.parametrize("isa", instructions)
268315
def test_enable_instructions(isa):
269316
mkl.enable_instructions(isa)
270317

271318

319+
@pytest.mark.skipif(
320+
_mkl_major >= 2026,
321+
reason="Removed in MKL 2026.0",
322+
)
323+
@pytest.mark.parametrize("isa", legacy_instructions)
324+
def test_enable_instructions_legacy(isa):
325+
mkl.enable_instructions(isa)
326+
327+
328+
@pytest.mark.skipif(
329+
_mkl_major < 2026,
330+
reason="Added in MKL 2026.0",
331+
)
332+
@pytest.mark.parametrize("isa", new_instructions)
333+
def test_enable_instructions_new(isa):
334+
mkl.enable_instructions(isa)
335+
336+
272337
def test_set_env_mode():
273338
mkl.set_env_mode()
274339

0 commit comments

Comments
 (0)