Commit efe3f4c

Add cuSolverDx JIT fusion and solver projections (#1176)
* Add cuSolverDx JIT fusion and solver projections

Add support for cuSolverDx within the JIT path. Solver functions differ from most operators because many of them return multiple values, which was previously incompatible with the MatX syntax. We introduce a new projection syntax where multiple return values can be used inside a single statement. This works because non-JIT mode uses the `mtie` syntax, so there is no ambiguity, while in JIT mode the operator adds no overhead since it is converted to a string.
1 parent a12442f commit efe3f4c

53 files changed

Lines changed: 7639 additions & 542 deletions

AGENTS.md

Lines changed: 10 additions & 0 deletions
@@ -9,3 +9,13 @@ the @.devcontainer directory using the create_base_container.sh script. The pyth
 ## Compiling and Running
 For use in external projects, MatX simply needs to be included by `#include <matx.h>`. To build unit tests the use the CMake option `MATX_BUILD_TESTS=ON`, benchmarks
 `MATX_BUILD_BENCHMARKS=ON`, and examples `MATX_BUILD_EXAMPLES=ON`. Individual tests can be compiled separately via different targets output from CMake.
+
+## Development Expectations
+Every new public function, operator, transform, or backend path should include accompanying unit tests and documentation updates. Tests should cover the supported
+types, ranks, batching behavior, and important error or fallback cases for the new behavior. Documentation should describe how to use the feature, required build
+options or dependencies, and any limitations that would affect users or future maintainers. Any change to operators should also verify
+@docs_input/executor_compatibility.rst and update it if executor support, limitations, or backend requirements changed.
+
+When adding or changing accelerated backends such as CUDA, cuBLAS, cuSolver, cuFFT, or MathDx paths, preserve the existing non-accelerated behavior unless the task
+explicitly calls for a breaking change. Prefer focused tests that compare the new backend against an existing trusted MatX path, and include negative tests for
+unsupported shape, dtype, rank, or configuration combinations when applicable.

CMakeLists.txt

Lines changed: 20 additions & 2 deletions
@@ -159,6 +159,9 @@ rapids_cpm_cccl(
 )
 
 target_link_libraries(matx INTERFACE CCCL::CCCL)
+if(CCCL_SOURCE_DIR)
+  target_compile_definitions(matx INTERFACE MATX_CCCL_SOURCE_DIR="${CCCL_SOURCE_DIR}")
+endif()
 
 # Set flags for compiling tests faster (only for nvcc)
 if (NOT CMAKE_CUDA_COMPILER_ID STREQUAL "Clang")
@@ -364,13 +367,28 @@ if (MATX_EN_JIT OR MATX_EN_MATHDX)
 
   # Link NVRTC library
   target_link_libraries(matx INTERFACE CUDA::nvrtc)
+  if(TARGET CUDA::nvJitLink)
+    target_link_libraries(matx INTERFACE CUDA::nvJitLink)
+  endif()
 endif()
 
 if (MATX_EN_MATHDX)
-  set(MathDx_VERSION 25.06)
+  set(MathDx_VERSION 26.03)
   set(MathDx_NANO 0)
+  set(LIBMATHDX_VERSION 0.3.2)
   include(cmake/FindMathDx.cmake)
   target_compile_definitions(matx INTERFACE MATX_EN_MATHDX)
+  target_compile_definitions(matx INTERFACE
+    MATX_MATHDX_INCLUDE_DIR="${MATHDX_INCLUDE_DIR}"
+    MATX_MATHDX_CUTLASS_INCLUDE_DIR="${MATHDX_CUTLASS_INCLUDE_DIR}"
+    MATX_LIBMATHDX_INCLUDE_DIR="${LIBMATHDX_INCLUDE_DIR}"
+  )
+  if(EXISTS "${MATHDX_ROOT}/lib/libcusolverdx.fatbin")
+    target_compile_definitions(matx INTERFACE MATX_CUSOLVERDX_FATBIN="${MATHDX_ROOT}/lib/libcusolverdx.fatbin")
+  endif()
+  if(EXISTS "${MATHDX_ROOT}/lib/libcusolverdx.a")
+    target_compile_definitions(matx INTERFACE MATX_CUSOLVERDX_LIBRARY="${MATHDX_ROOT}/lib/libcusolverdx.a")
+  endif()
 
   # Link libmathdx if available
   if(TARGET libmathdx::libmathdx)
@@ -379,7 +397,7 @@ if (MATX_EN_MATHDX)
   endif()
 
   # Link mathdx components
-  target_link_libraries(matx INTERFACE mathdx::cufftdx)
+  target_link_libraries(matx INTERFACE mathdx::cufftdx mathdx::cublasdx_no_lto mathdx::cusolverdx)
 endif()
 
 if (MATX_EN_CUDSS)
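
How MatX consumes the path definitions added above is not shown in this part of the diff. As a minimal sketch, assuming they are forwarded to a runtime compiler as `-I` include options (a form NVRTC accepts), the hypothetical helpers `include_flags` and `configured_include_flags` below show one way string-literal compile definitions could be turned into option strings, with empty fallbacks for builds where the definitions are absent:

```cpp
#include <cassert>
#include <string>
#include <vector>

// These macros are the INTERFACE compile definitions added in this diff.
// Assumption: when MATX_EN_MATHDX is off they are undefined, so each one gets
// an empty-string fallback to keep this sketch compilable on its own.
#ifndef MATX_MATHDX_INCLUDE_DIR
#define MATX_MATHDX_INCLUDE_DIR ""
#endif
#ifndef MATX_MATHDX_CUTLASS_INCLUDE_DIR
#define MATX_MATHDX_CUTLASS_INCLUDE_DIR ""
#endif
#ifndef MATX_LIBMATHDX_INCLUDE_DIR
#define MATX_LIBMATHDX_INCLUDE_DIR ""
#endif

// Hypothetical helper (not a MatX function): turn directories into "-I<dir>"
// options, skipping entries that were not configured.
std::vector<std::string> include_flags(const std::vector<std::string>& dirs) {
  std::vector<std::string> flags;
  for (const std::string& dir : dirs) {
    if (!dir.empty()) {
      flags.push_back("-I" + dir);
    }
  }
  return flags;
}

// Hypothetical helper: the include options implied by the definitions above.
std::vector<std::string> configured_include_flags() {
  return include_flags({MATX_MATHDX_INCLUDE_DIR,
                        MATX_MATHDX_CUTLASS_INCLUDE_DIR,
                        MATX_LIBMATHDX_INCLUDE_DIR});
}
```

Because the definitions are attached to the `matx` target as `INTERFACE` properties, any target that links `matx` would see the same macro values at compile time.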

cmake/FindMathDx.cmake

Lines changed: 54 additions & 125 deletions
@@ -1,5 +1,5 @@
 #=============================================================================
-# Copyright (c) 2021, NVIDIA CORPORATION.
+# Copyright (c) 2021-2026, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,84 +14,25 @@
 # limitations under the License.
 #=============================================================================
 
-#[=======================================================================[.rst:
-FindMathDx
---------
-
-Find MathDx
-
-Imported targets
-^^^^^^^^^^^^^^^^
-
-This module defines the following :prop_tgt:`IMPORTED` target(s):
-
-``MathDx::MathDx``
-  The MathDx library, if found.
-
-Result variables
-^^^^^^^^^^^^^^^^
-
-This module will set the following variables in your project:
-
-``MathDx_FOUND``
-  True if MathDx is found.
-``MathDx_INCLUDE_DIRS``
-  The include directories needed to use MathDx.
-``MathDx_VERSION_STRING``
-  The version of the MathDx library found. [OPTIONAL]
-
-#]=======================================================================]
 set(MathDx_VERSION_FULL ${MathDx_VERSION}.${MathDx_NANO})
 
-# Prefer using a Config module if it exists for this project
-set(MathDx_NO_CONFIG FALSE)
-if(NOT MathDx_NO_CONFIG)
-  find_package(MathDx CONFIG QUIET HINTS ${MathDx_DIR})
-  if(MathDx_FOUND)
-    find_package_handle_standard_args(MathDx DEFAULT_MSG MathDx_CONFIG)
-    return()
-  endif()
+if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+  message(FATAL_ERROR "MATX_EN_MATHDX requires CUDA 13.0 or newer for MathDx ${MathDx_VERSION_FULL}")
 endif()
 
-find_path(MathDx_INCLUDE_DIR NAMES MathDx.h)
+set(MATHDX_CUDA_VERSION "cuda13")
+set(MATHDX_CUDA_SUFFIX "cuda13.0")
 
-# Search for the MathDx library
-find_library(MathDx_LIBRARY
-  NAMES MathDx mathdx
-  HINTS ${MathDx_DIR}
-  PATH_SUFFIXES lib lib64
-)
-
-include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
-
-find_package_handle_standard_args(MathDx
-  REQUIRED_VARS MathDx_LIBRARY MathDx_INCLUDE_DIR
-  VERSION_VAR )
-
-if(NOT MathDx_FOUND)
-  set(MathDx_FILENAME libMathDx-linux-x86_64-${MathDx_VERSION}-archive)
-
-  message(STATUS "MathDx not found. Downloading library. By continuing this download you accept to the license terms of MathDx")
-
-  CPMAddPackage(
-    NAME MathDx
-    VERSION ${MathDx_VERSION}
-    URL https://developer.download.nvidia.com/compute/cuFFTDx/redist/cuFFTDx/nvidia-mathdx-${MathDx_VERSION_FULL}.tar.gz
-    DOWNLOAD_ONLY YES
-  )
-endif()
+message(STATUS "Using MathDx ${MathDx_VERSION_FULL} (${MATHDX_CUDA_VERSION})")
+message(STATUS "Using libmathdx ${LIBMATHDX_VERSION} (${MATHDX_CUDA_VERSION})")
 
-# Download libmathdx based on CUDA version and platform
-# Detect CUDA version (12 or 13)
-if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
-  set(LIBMATHDX_CUDA_VERSION "cuda13")
-  set(LIBMATHDX_CUDA_SUFFIX "cuda13.0")
-else()
-  set(LIBMATHDX_CUDA_VERSION "cuda12")
-  set(LIBMATHDX_CUDA_SUFFIX "cuda12.0")
-endif()
+CPMAddPackage(
+  NAME MathDx
+  VERSION ${MathDx_VERSION_FULL}
+  URL https://developer.nvidia.com/downloads/compute/cuSOLVERDx/redist/cuSOLVERDx/${MATHDX_CUDA_VERSION}/nvidia-mathdx-${MathDx_VERSION_FULL}-${MATHDX_CUDA_VERSION}.tar.gz
+  DOWNLOAD_ONLY YES
+)
 
-# Detect platform
 if(WIN32)
   set(LIBMATHDX_PLATFORM "win32-x86_64")
   set(LIBMATHDX_EXT "zip")
@@ -103,63 +44,51 @@ elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
   endif()
   set(LIBMATHDX_EXT "tar.gz")
 else()
-  message(WARNING "Unsupported platform for libmathdx download")
+  message(FATAL_ERROR "Unsupported platform for libmathdx download")
 endif()
 
-# Set libmathdx version
-set(LIBMATHDX_VERSION "0.2.3")
-
-# Download libmathdx if platform is supported
-if(DEFINED LIBMATHDX_PLATFORM)
-  set(LIBMATHDX_URL "https://developer.nvidia.com/downloads/compute/cublasdx/redist/cublasdx/${LIBMATHDX_CUDA_VERSION}/libmathdx-${LIBMATHDX_PLATFORM}-${LIBMATHDX_VERSION}-${LIBMATHDX_CUDA_SUFFIX}.${LIBMATHDX_EXT}")
-
-  message(STATUS "Downloading libmathdx for ${LIBMATHDX_PLATFORM} with ${LIBMATHDX_CUDA_VERSION}")
-  message(STATUS "libmathdx URL: ${LIBMATHDX_URL}")
-
-  CPMAddPackage(
-    NAME libmathdx
-    VERSION ${LIBMATHDX_VERSION}
-    URL ${LIBMATHDX_URL}
-    DOWNLOAD_ONLY YES
-  )
-
-  # Add libmathdx to the search paths
-  set(LIBMATHDX_ROOT "${PROJECT_BINARY_DIR}/_deps/libmathdx-src")
-  list(APPEND CMAKE_PREFIX_PATH "${LIBMATHDX_ROOT}")
-
-  # Find libmathdx library file
-  find_library(LIBMATHDX_LIBRARY
-    NAMES mathdx libmathdx
-    PATHS "${LIBMATHDX_ROOT}/lib"
-    NO_DEFAULT_PATH
+set(LIBMATHDX_URL "https://developer.nvidia.com/downloads/compute/cublasdx/redist/cublasdx/${MATHDX_CUDA_VERSION}/libmathdx-${LIBMATHDX_PLATFORM}-${LIBMATHDX_VERSION}-${MATHDX_CUDA_SUFFIX}.${LIBMATHDX_EXT}")
+
+message(STATUS "libmathdx URL: ${LIBMATHDX_URL}")
+
+CPMAddPackage(
+  NAME libmathdx
+  VERSION ${LIBMATHDX_VERSION}
+  URL ${LIBMATHDX_URL}
+  DOWNLOAD_ONLY YES
+)
+
+set(MATX_MATHDX_ROOT "${PROJECT_BINARY_DIR}/_deps/mathdx-src/nvidia/mathdx/${MathDx_VERSION}")
+set(MATHDX_INCLUDE_DIR "${MATX_MATHDX_ROOT}/include")
+set(MATHDX_CUTLASS_INCLUDE_DIR "${MATX_MATHDX_ROOT}/external/cutlass/include")
+set(LIBMATHDX_ROOT "${PROJECT_BINARY_DIR}/_deps/libmathdx-src")
+set(LIBMATHDX_INCLUDE_DIR "${LIBMATHDX_ROOT}/include")
+
+find_library(LIBMATHDX_LIBRARY
+  NAMES mathdx libmathdx
+  PATHS "${LIBMATHDX_ROOT}/lib"
+  NO_DEFAULT_PATH
+)
+
+if(NOT LIBMATHDX_LIBRARY OR NOT EXISTS "${LIBMATHDX_INCLUDE_DIR}/libmathdx.h")
+  message(FATAL_ERROR "Could not find libmathdx ${LIBMATHDX_VERSION} library or headers after download")
+endif()
+
+if(NOT TARGET libmathdx::libmathdx)
+  add_library(libmathdx::libmathdx INTERFACE IMPORTED)
+  set_target_properties(libmathdx::libmathdx PROPERTIES
+    INTERFACE_INCLUDE_DIRECTORIES "${LIBMATHDX_INCLUDE_DIR}"
+    INTERFACE_LINK_LIBRARIES "${LIBMATHDX_LIBRARY}"
   )
-
-  # Set include directories (in both local and parent scope)
-  set(LIBMATHDX_INCLUDE_DIR "${LIBMATHDX_ROOT}/include")
-  set(LIBMATHDX_INCLUDE_DIR "${LIBMATHDX_INCLUDE_DIR}" PARENT_SCOPE)
-
-  if(LIBMATHDX_LIBRARY AND EXISTS ${LIBMATHDX_INCLUDE_DIR})
-    message(STATUS "Found libmathdx library: ${LIBMATHDX_LIBRARY}")
-    message(STATUS "Found libmathdx include dir: ${LIBMATHDX_INCLUDE_DIR}")
-
-    # Create libmathdx target
-    if(NOT TARGET libmathdx::libmathdx)
-      add_library(libmathdx::libmathdx INTERFACE IMPORTED)
-      set_target_properties(libmathdx::libmathdx PROPERTIES
-        INTERFACE_INCLUDE_DIRECTORIES "${LIBMATHDX_INCLUDE_DIR}"
-        INTERFACE_LINK_LIBRARIES "${LIBMATHDX_LIBRARY}"
-      )
-    endif()
-  else()
-    message(WARNING "Could not find libmathdx library or include directory after download")
-  endif()
 endif()
 
-find_package(mathdx REQUIRED COMPONENTS cufftdx CONFIG
-  PATHS
-  "${PROJECT_BINARY_DIR}/_deps/mathdx-src/nvidia/mathdx/${MathDx_VERSION}/lib/cmake/mathdx/"
-  "${PROJECT_BINARY_DIR}/_deps/libmathdx-src/lib/cmake/libmathdx/"
-  "${PROJECT_BINARY_DIR}/_deps/libmathdx-src"
-  "/opt/nvidia/mathdx/${MathDx_VERSION_FULL}"
+set(cublasdx_CUTLASS_ROOT "${MATX_MATHDX_ROOT}/external/cutlass")
+set(cusolverdx_CUTLASS_ROOT "${MATX_MATHDX_ROOT}/external/cutlass")
+
+find_package(mathdx REQUIRED COMPONENTS cufftdx cublasdx cusolverdx CONFIG
+  PATHS
+  "${MATX_MATHDX_ROOT}/lib/cmake/mathdx"
+  "/opt/nvidia/mathdx/${MathDx_VERSION}"
 )
 
+set(MATHDX_ROOT "${MATX_MATHDX_ROOT}")

docs_input/api/linalg/decomp/chol.rst

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,12 @@ Perform a Cholesky factorization.
 .. note::
    The input matrix must be symmetric positive-definite
 
+.. note::
+   CUDA JIT kernel fusion is supported by cuSolverDx if ``-DMATX_EN_MATHDX=ON`` is enabled. The current JIT path
+   supports rank 2 through 4 square matrices with ``float``, ``double``, ``complex<float>``, and
+   ``complex<double>`` values. Unsupported ranks, shapes, or data types should use a normal executor and will be
+   rejected by ``CUDAJITExecutor``.
+
 .. versionadded:: 0.6.0
 
 .. doxygenfunction:: chol
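
The support limits stated in the note can be summarized as a small predicate. This is only a sketch of the documented contract, not the check that `CUDAJITExecutor` performs: `DType` and `jit_solver_supported` are hypothetical names, and it assumes the matrix occupies the last two dimensions with any batch dimensions leading.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dtype tag for this sketch; MatX works with real C++ types.
enum class DType { Float, Double, ComplexFloat, ComplexDouble, Half, Int32 };

// Hypothetical predicate mirroring the documented JIT limits:
// rank 2 through 4, square trailing matrix dimensions, four value types.
bool jit_solver_supported(const std::vector<std::size_t>& shape, DType dtype) {
  const std::size_t rank = shape.size();
  if (rank < 2 || rank > 4) {
    return false;  // unsupported rank
  }
  if (shape[rank - 1] != shape[rank - 2]) {
    return false;  // matrices must be square
  }
  switch (dtype) {
    case DType::Float:
    case DType::Double:
    case DType::ComplexFloat:
    case DType::ComplexDouble:
      return true;
    default:
      return false;  // any other value type is outside the documented set
  }
}
```

A check of this shape is also a convenient source of negative test cases (wrong rank, non-square shape, unsupported type) of the kind the development expectations in `AGENTS.md` ask for.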

docs_input/api/linalg/decomp/inverse.rst

Lines changed: 13 additions & 2 deletions
@@ -11,6 +11,19 @@ Compute the inverse of a square matrix.
 .. note::
    This function is currently not supported with host-based executors (CPU)
 
+.. note::
+   CUDA JIT kernel fusion is supported by cuSolverDx if ``-DMATX_EN_MATHDX=ON`` is enabled. The current JIT path
+   supports rank 2 through 4 square matrices with ``float``, ``double``, ``complex<float>``, and
+   ``complex<double>`` values. Unsupported ranks, shapes, or data types should use a normal executor and will be
+   rejected by ``CUDAJITExecutor``.
+
+.. note::
+   The default inverse uses LU/GETRF-style factorization for general square matrices. For Hermitian
+   positive-definite inputs, ``inv<MAT_INVERSE_ALGO_POSV>(A)`` enables a cuSolverDx POSV JIT path that solves
+   ``A * X = I`` directly and can fuse with other compatible JIT operators such as cuBLASDx matmul. POSV is
+   currently a MathDx/CUDAJITExecutor-only inverse algorithm, and callers must ensure the input satisfies the
+   positive-definite contract.
+
 .. versionadded:: 0.6.0
 
 .. doxygenfunction:: inv(const OpA &a)
@@ -23,5 +36,3 @@ Examples
    :start-after: example-begin inv-test-1
    :end-before: example-end inv-test-1
    :dedent:
-
-
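
To illustrate what solving `A * X = I` for a positive-definite input means numerically, here is a plain host-side C++ sketch for real symmetric matrices. It is not the cuSolverDx POSV path and uses no MatX API; `cholesky_lower` and `spd_inverse` are hypothetical helpers that factor `A = L * L^T` and then solve for one column of the identity at a time.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Lower-triangular Cholesky factor L such that A = L * L^T.
// The asserts encode the caller's contract: square and positive-definite.
Matrix cholesky_lower(const Matrix& a) {
  const std::size_t n = a.size();
  for (const std::vector<double>& row : a) {
    assert(row.size() == n);  // square input
  }
  Matrix l(n, std::vector<double>(n, 0.0));
  for (std::size_t j = 0; j < n; ++j) {
    double diag = a[j][j];
    for (std::size_t k = 0; k < j; ++k) diag -= l[j][k] * l[j][k];
    assert(diag > 0.0);  // fails when A is not positive-definite
    l[j][j] = std::sqrt(diag);
    for (std::size_t i = j + 1; i < n; ++i) {
      double s = a[i][j];
      for (std::size_t k = 0; k < j; ++k) s -= l[i][k] * l[j][k];
      l[i][j] = s / l[j][j];
    }
  }
  return l;
}

// Inverse of a symmetric positive-definite matrix by solving A * X = I.
Matrix spd_inverse(const Matrix& a) {
  const std::size_t n = a.size();
  const Matrix l = cholesky_lower(a);
  Matrix x(n, std::vector<double>(n, 0.0));
  std::vector<double> y(n, 0.0);
  for (std::size_t c = 0; c < n; ++c) {
    // Forward substitution: L * y = e_c (column c of the identity).
    for (std::size_t i = 0; i < n; ++i) {
      double s = (i == c) ? 1.0 : 0.0;
      for (std::size_t k = 0; k < i; ++k) s -= l[i][k] * y[k];
      y[i] = s / l[i][i];
    }
    // Back substitution: L^T * x_c = y.
    for (std::size_t i = n; i-- > 0;) {
      double s = y[i];
      for (std::size_t k = i + 1; k < n; ++k) s -= l[k][i] * x[k][c];
      x[i][c] = s / l[i][i];
    }
  }
  return x;
}
```

For `A = [[4, 2], [2, 3]]` this yields `[[0.375, -0.25], [-0.25, 0.5]]`. The assert on the diagonal shows why the positive-definite contract matters: an input that violates it has no real Cholesky factor.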

docs_input/api/linalg/decomp/lu.rst

Lines changed: 18 additions & 0 deletions
@@ -7,6 +7,12 @@ Perform an LU factorization.
 
 .. versionadded:: 0.6.0
 
+.. note::
+   The ``mtie`` assignment form of ``lu`` uses the normal non-JIT solver path. CUDA JIT fusion is available through
+   lazy projection members such as ``lu(A).LU`` and ``lu(A).Piv`` when ``-DMATX_EN_MATHDX=ON`` is enabled and the
+   runtime shape/type is supported by cuSolverDx. Projection JIT currently supports ranks 2 through 4 and
+   ``float``, ``double``, ``complex<float>``, and ``complex<double>`` inputs.
+
 .. doxygenfunction:: lu
 
 Examples
@@ -17,3 +23,15 @@ Examples
    :start-after: example-begin lu-test-1
    :end-before: example-end lu-test-1
    :dedent:
+
+Projection Examples
+~~~~~~~~~~~~~~~~~~~
+
+Lazy projections let ``LU`` and ``Piv`` participate in a larger expression. This example is included from the
+``CuSolverDxPivotProjectionUsedInFusedExpression`` unit test.
+
+.. literalinclude:: ../../../../test/00_solver/LU.cu
+   :language: cpp
+   :start-after: example-begin lu-projection-test-1
+   :end-before: example-end lu-projection-test-1
+   :dedent:
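
The idea of reading both outputs of one factorization inside a single expression can be sketched in plain host-side C++. Nothing here is MatX code: `lu_factor`, `LuResult`, and `determinant` are hypothetical, the factorization is eager rather than lazy, and the 0-based pivot convention is an assumption (the diff does not show which convention `Piv` uses). The member names `LU` and `Piv` only mirror the projection names in the note.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Hypothetical result type: both outputs of one factorization, by name.
struct LuResult {
  Matrix LU;                     // unit-lower L and upper U packed together
  std::vector<std::size_t> Piv;  // Piv[k] = row swapped with row k (0-based)
};

// LU factorization with partial pivoting (Doolittle form), done in place on a
// copy of the input.
LuResult lu_factor(Matrix a) {
  const std::size_t n = a.size();
  for (const std::vector<double>& row : a) {
    assert(row.size() == n);  // square input
  }
  std::vector<std::size_t> piv(n);
  for (std::size_t k = 0; k < n; ++k) {
    // Pick the row with the largest magnitude in column k.
    std::size_t p = k;
    for (std::size_t i = k + 1; i < n; ++i) {
      if (std::fabs(a[i][k]) > std::fabs(a[p][k])) p = i;
    }
    assert(a[p][k] != 0.0);  // singular matrices are out of scope here
    piv[k] = p;
    std::swap(a[k], a[p]);
    // Eliminate below the pivot, storing the multipliers in place.
    for (std::size_t i = k + 1; i < n; ++i) {
      a[i][k] /= a[k][k];
      for (std::size_t j = k + 1; j < n; ++j) a[i][j] -= a[i][k] * a[k][j];
    }
  }
  return {std::move(a), std::move(piv)};
}

// One expression that consumes both members: the product of U's diagonal,
// with the sign flipped once per row swap recorded in Piv.
double determinant(const LuResult& f) {
  double det = 1.0;
  for (std::size_t k = 0; k < f.LU.size(); ++k) {
    det *= f.LU[k][k];
    if (f.Piv[k] != k) det = -det;
  }
  return det;
}
```

For `A = [[1, 2], [3, 4]]` the single row swap flips the sign and `determinant(lu_factor(A))` comes out at -2 up to rounding. A fused JIT expression would consume the two projections in a similar way, without materializing them through a separate `mtie` statement.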
