@@ -9,10 +9,13 @@ set(BASE_HEADERS
99 utils.h)
1010
1111function (build_kernel_base TARGET SRCFILE DEPS )
12- set (METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
12+ set (METAL_FLAGS -x metal - Wall -Wextra -fno-fast-math -Wno-c++17-extensions)
1313 if (MLX_METAL_DEBUG)
1414 set (METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources)
1515 endif ()
16+ if (MLX_ENABLE_NAX)
17+ set (METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0)
18+ endif ()
1619 if (NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "" )
1720 set (METAL_FLAGS ${METAL_FLAGS}
1821 "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET} " )
@@ -120,6 +123,30 @@ if(NOT MLX_METAL_JIT)
120123 build_kernel (gemv_masked steel/utils.h )
121124endif ()
122125
126+ if (MLX_ENABLE_NAX)
127+
128+ set (STEEL_NAX_HEADERS
129+ steel/defines.h
130+ steel/utils.h
131+ steel/gemm/transforms.h
132+ steel/gemm/nax.h
133+ steel/gemm/gemm_nax.h
134+ steel/utils/type_traits.h
135+ steel/utils/integral_constant.h)
136+
137+ build_kernel (steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS} )
138+ build_kernel (steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS} )
139+
140+ build_kernel (quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS} )
141+ build_kernel (fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS} )
142+
143+ set (STEEL_NAX_ATTN_HEADERS
144+ steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h
145+ steel/utils/integral_constant.h)
146+
147+ build_kernel (steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS} )
148+ endif ()
149+
123150add_custom_command (
124151 OUTPUT ${MLX_METAL_PATH} /mlx.metallib
125152 COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
0 commit comments