Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def get_default_flags(options):
"-lrt_decoder",
]

rt_remote_so = "librt_remote" + file_extension
if os.path.isfile(os.path.join(rt_lib_path, rt_remote_so)):
default_flags.append("-lrt_remote")

# If OQD runtime capi is built, link to it as well
# TODO: This is not ideal and should be replaced when the compiler is device aware
if os.path.isfile(os.path.join(rt_lib_path, "librt_OQD_capi" + file_extension)):
Expand Down
7 changes: 4 additions & 3 deletions mlir/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ USE_SANITIZER_NAMES=""
USE_SANITIZER_FLAGS=""
endif

LLVM_PROJECTS ?= mlir
LLVM_TARGETS ?= check-mlir llvm-symbolizer
LLVM_PROJECTS ?= mlir;lld
LLVM_TARGETS ?= check-mlir llvm-symbolizer lld llvm-strip
LLVM_TARGETS_TO_BUILD ?= host;AArch64;X86

.PHONY: help
help:
Expand Down Expand Up @@ -75,7 +76,7 @@ llvm:
cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \
-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="host" \
-DLLVM_TARGETS_TO_BUILD="$(LLVM_TARGETS_TO_BUILD)" \
-DLLVM_ENABLE_PROJECTS="$(LLVM_PROJECTS)" \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
Expand Down
62 changes: 62 additions & 0 deletions mlir/include/Catalyst/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,68 @@ def ApplyTransformSequencePass : Pass<"apply-transform-sequence"> {
let summary = "Apply the passes scheduled with the transform dialect.";
}

def CrossCompileRemoteKernelsPass : Pass<"cross-compile-remote-kernels", "mlir::ModuleOp"> {
let summary = "Cross-compile each `qnode` to a `.o` and stash the path on the host call sites.";
let description = [{
For every `func.func` in the host module carrying the `qnode` attribute, this pass:
1. Clones the func into a standalone module.
2. Sanitizes it for lowering to LLVM IR and emits the `.o` to `<workspace>/<name>.o`.
3. Marks every `func.call` targeting the qnode with `catalyst.remote_kernel_path`
holding the absolute path of the produced `.o`.
}];

let dependentDialects = [
"mlir::LLVM::LLVMDialect",
"mlir::func::FuncDialect",
];

let options = [
Option<
/*C++ var name=*/"workspace",
/*CLI arg name=*/"workspace",
/*type=*/"std::string",
/*default=*/"\"\"",
/*description=*/
"Filesystem directory to write cross-compiled `.o` files into."
>,
Option<
/*C++ var name=*/"target",
/*CLI arg name=*/"target",
/*type=*/"std::string",
/*default=*/"\"x86_64\"",
/*description=*/
"LLVM target triple used for object emission."
>,
Option<
/*C++ var name=*/"cpu",
/*CLI arg name=*/"cpu",
/*type=*/"std::string",
/*default=*/"\"generic\"",
/*description=*/
"LLVM CPU model fed to `createTargetMachine`. Defaults to "
"`generic`, which emits baseline code for the triple."
>,
Option<
/*C++ var name=*/"features",
/*CLI arg name=*/"features",
/*type=*/"std::string",
/*default=*/"\"\"",
/*description=*/
"Comma-separated subtarget features fed to `createTargetMachine` "
"(each token must be `+feat` or `-feat`, e.g. `+crc,+aes,+sha2`). "
"Empty (default) lets the CPU choose its own feature set."
>,
Option<
/*C++ var name=*/"address",
/*CLI arg name=*/"address",
/*type=*/"std::string",
/*default=*/"\"\"",
/*description=*/
"Executor's TCP `host:port` address for remote dispatch."
>
];
}

def InlineNestedModulePass : Pass<"inline-nested-module"> {
let summary = "Inline nested modules with qnode attribute.";

Expand Down
12 changes: 6 additions & 6 deletions mlir/include/Driver/DefaultPipelines/DefaultPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ const PipelineList pipelineList{
"register-inactive-callback"}}};
// clang-format on

PipelineNames getPipelineNames()
inline PipelineNames getPipelineNames()
{
static std::vector<std::string> names =
std::accumulate(driver::pipelineList.begin(), driver::pipelineList.end(),
Expand All @@ -174,7 +174,7 @@ PipelineNames getPipelineNames()
return names;
}

PassNames getQuantumCompilationStage(bool disableAssertion = true)
inline PassNames getQuantumCompilationStage(bool disableAssertion = true)
{
PassNames ret;
std::copy_if(pipelineList[0].passNames.begin(), pipelineList[0].passNames.end(),
Expand All @@ -184,11 +184,11 @@ PassNames getQuantumCompilationStage(bool disableAssertion = true)
return ret;
}

PassNames getHLOLoweringStage() { return pipelineList[1].passNames; }
inline PassNames getHLOLoweringStage() { return pipelineList[1].passNames; }

PassNames getGradientLoweringStage() { return pipelineList[2].passNames; }
inline PassNames getGradientLoweringStage() { return pipelineList[2].passNames; }

PassNames getBufferizationStage(bool asyncQNodes = false)
inline PassNames getBufferizationStage(bool asyncQNodes = false)
{
const std::string bufferizationOptions =
std::string("{bufferize-function-boundaries ") + "allow-return-allocs-from-loops " +
Expand All @@ -206,7 +206,7 @@ PassNames getBufferizationStage(bool asyncQNodes = false)
return ret;
}

PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false)
inline PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false)
{
PassNames ret;
std::copy_if(pipelineList[4].passNames.begin(), pipelineList[4].passNames.end(),
Expand Down
9 changes: 6 additions & 3 deletions mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,13 @@ struct CustomCallOpInterface
// Add the initial number of arguments
int32_t numArguments = static_cast<int32_t>(customCallOp.getNumOperands());
IntegerAttr numArgumentsAttr = rewriter.getI32IntegerAttr(numArguments);
auto newOp = CustomCallOp::create(rewriter, op->getLoc(), TypeRange{}, bufferArgs,
customCallOp.getCallTargetName(), numArgumentsAttr);

// Create an updated custom call operation
CustomCallOp::create(rewriter, op->getLoc(), TypeRange{}, bufferArgs,
customCallOp.getCallTargetName(), numArgumentsAttr);
// carry over any discardable attributes
for (NamedAttribute attr : op->getDiscardableAttrs()) {
newOp->setAttr(attr.getName(), attr.getValue());
}
size_t startIndex = bufferArgs.size() - customCallOp.getNumResults();
SmallVector<Value> bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end());
bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults);
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Catalyst/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ file(GLOB SRC
BufferDeallocation.cpp
BufferizableOpInterfaceImpl.cpp
catalyst_to_llvm.cpp
CrossCompileRemoteKernels.cpp
DetectQNodes.cpp
DetensorizeFunctionBoundaryPass.cpp
DetensorizeSCFPass.cpp
Expand All @@ -29,11 +30,18 @@ file(GLOB SRC
TBAATagsPass.cpp
)

set(LLVM_LINK_COMPONENTS
AllTargetsAsmParsers
AllTargetsCodeGens
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
set(LIBS
${dialect_libs}
${conversion_libs}
${translation_libs}
catalyst-analysis
)

Expand Down
Loading
Loading