Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mlir/include/Driver/DefaultPipelines/DefaultPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ const PipelineList pipelineList{
"register-inactive-callback"}}};
// clang-format on

PipelineNames getPipelineNames()
inline PipelineNames getPipelineNames()
{
static std::vector<std::string> names =
std::accumulate(driver::pipelineList.begin(), driver::pipelineList.end(),
Expand All @@ -179,7 +179,7 @@ PipelineNames getPipelineNames()
return names;
}

PassNames getQuantumCompilationStage(bool disableAssertion = true)
inline PassNames getQuantumCompilationStage(bool disableAssertion = true)
{
PassNames ret;
std::copy_if(pipelineList[0].passNames.begin(), pipelineList[0].passNames.end(),
Expand All @@ -189,11 +189,11 @@ PassNames getQuantumCompilationStage(bool disableAssertion = true)
return ret;
}

PassNames getHLOLoweringStage() { return pipelineList[1].passNames; }
inline PassNames getHLOLoweringStage() { return pipelineList[1].passNames; }

PassNames getGradientLoweringStage() { return pipelineList[2].passNames; }
inline PassNames getGradientLoweringStage() { return pipelineList[2].passNames; }

PassNames getBufferizationStage(bool asyncQNodes = false)
inline PassNames getBufferizationStage(bool asyncQNodes = false)
{
const std::string bufferizationOptions =
std::string("{bufferize-function-boundaries ") + "allow-return-allocs-from-loops " +
Expand All @@ -211,7 +211,7 @@ PassNames getBufferizationStage(bool asyncQNodes = false)
return ret;
}

PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false)
inline PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false)
{
PassNames ret;
std::copy_if(pipelineList[4].passNames.begin(), pipelineList[4].passNames.end(),
Expand Down
64 changes: 64 additions & 0 deletions mlir/include/Remote/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,70 @@

include "mlir/Pass/PassBase.td"

def CrossCompileRemoteKernelsPass : Pass<"cross-compile-remote-kernels", "mlir::ModuleOp"> {
let summary = "Cross-compile each `qnode` to a `.o` and emit `remote.*` ops at the host call sites.";
let description = [{
For every `func.func` in the host module carrying the `qnode` attribute, this pass:
1. Clones the func into a standalone module.
2. Sanitizes it for lowering to LLVM IR and emits the `.o` to `<workspace>/<name>.o`.
3. Marks every `func.call` targeting the qnode with `catalyst.remote_kernel_path`
holding the absolute path of the produced `.o`.
}];

let dependentDialects = [
"mlir::LLVM::LLVMDialect",
"mlir::func::FuncDialect",
"catalyst::CatalystDialect",
"catalyst::remote::RemoteDialect",
];

let options = [
Option<
/*C++ var name=*/"workspace",
/*CLI arg name=*/"workspace",
/*type=*/"std::string",
/*default=*/"\"\"",
/*description=*/
"Filesystem directory to write cross-compiled `.o` files into."
>,
Option<
/*C++ var name=*/"target",
/*CLI arg name=*/"target",
/*type=*/"std::string",
/*default=*/"\"x86_64\"",
/*description=*/
"LLVM target triple used for object emission."
>,
Option<
/*C++ var name=*/"cpu",
/*CLI arg name=*/"cpu",
/*type=*/"std::string",
/*default=*/"\"generic\"",
/*description=*/
"LLVM CPU model fed to `createTargetMachine`. Defaults to "
"`generic`, which emits baseline code for the triple."
>,
Option<
/*C++ var name=*/"features",
/*CLI arg name=*/"features",
/*type=*/"std::string",
/*default=*/"\"\"",
/*description=*/
"Comma-separated subtarget features fed to `createTargetMachine` "
"(each token must be `+feat` or `-feat`, e.g. `+crc,+aes,+sha2`). "
"Empty (default) lets the CPU choose its own feature set."
>,
Option<
/*C++ var name=*/"address",
/*CLI arg name=*/"address",
/*type=*/"std::string",
/*default=*/"\"\"",
/*description=*/
"Executor's TCP `host:port` address for remote dispatch."
>
];
}

def ConvertRemoteToLLVMPass : Pass<"convert-remote-to-llvm", "mlir::ModuleOp"> {
let summary = "Lower the `remote` dialect to direct calls into the Catalyst remote runtime.";
let description = [{
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Remote/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
set(LIBRARY_NAME remote-transforms)

set(LLVM_LINK_COMPONENTS
AllTargetsAsmParsers
AllTargetsCodeGens
)

file(GLOB SRC
CrossCompileRemoteKernels.cpp
RemoteToLLVM.cpp
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
set(LIBS
${dialect_libs}
${conversion_libs}
${translation_libs}
MLIRRemote
MLIRCatalyst
)
Expand Down
Loading
Loading