diff --git a/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h b/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h index 7430913cdb..b6d3908f68 100644 --- a/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h +++ b/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h @@ -168,7 +168,7 @@ const PipelineList pipelineList{ "register-inactive-callback"}}}; // clang-format on -PipelineNames getPipelineNames() +inline PipelineNames getPipelineNames() { static std::vector names = std::accumulate(driver::pipelineList.begin(), driver::pipelineList.end(), @@ -179,7 +179,7 @@ PipelineNames getPipelineNames() return names; } -PassNames getQuantumCompilationStage(bool disableAssertion = true) +inline PassNames getQuantumCompilationStage(bool disableAssertion = true) { PassNames ret; std::copy_if(pipelineList[0].passNames.begin(), pipelineList[0].passNames.end(), @@ -189,11 +189,11 @@ PassNames getQuantumCompilationStage(bool disableAssertion = true) return ret; } -PassNames getHLOLoweringStage() { return pipelineList[1].passNames; } +inline PassNames getHLOLoweringStage() { return pipelineList[1].passNames; } -PassNames getGradientLoweringStage() { return pipelineList[2].passNames; } +inline PassNames getGradientLoweringStage() { return pipelineList[2].passNames; } -PassNames getBufferizationStage(bool asyncQNodes = false) +inline PassNames getBufferizationStage(bool asyncQNodes = false) { const std::string bufferizationOptions = std::string("{bufferize-function-boundaries ") + "allow-return-allocs-from-loops " + @@ -211,7 +211,7 @@ PassNames getBufferizationStage(bool asyncQNodes = false) return ret; } -PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false) +inline PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false) { PassNames ret; std::copy_if(pipelineList[4].passNames.begin(), pipelineList[4].passNames.end(), diff --git a/mlir/include/Remote/Transforms/Passes.td b/mlir/include/Remote/Transforms/Passes.td index c37a9d6215..a6fac00f32 100644 --- a/mlir/include/Remote/Transforms/Passes.td +++ b/mlir/include/Remote/Transforms/Passes.td @@ -17,6 +17,70 @@ include "mlir/Pass/PassBase.td" +def CrossCompileRemoteKernelsPass : Pass<"cross-compile-remote-kernels", "mlir::ModuleOp"> { + let summary = "Cross-compile each `qnode` to a `.o` and emit `remote.*` ops at the host call sites."; + let description = [{ + For every `func.func` in the host module carrying the `qnode` attribute, this pass: + 1. Clones the func into a standalone module. + 2. Sanitizes it for lowering to LLVM IR and emits the `.o` to `/.o`. + 3. Marks every `func.call` targeting the qnode with `catalyst.remote_kernel_path` + holding the absolute path of the produced `.o`. + }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "mlir::func::FuncDialect", + "catalyst::CatalystDialect", + "catalyst::remote::RemoteDialect", + ]; + + let options = [ + Option< + /*C++ var name=*/"workspace", + /*CLI arg name=*/"workspace", + /*type=*/"std::string", + /*default=*/"\"\"", + /*description=*/ + "Filesystem directory to write cross-compiled `.o` files into." + >, + Option< + /*C++ var name=*/"target", + /*CLI arg name=*/"target", + /*type=*/"std::string", + /*default=*/"\"x86_64\"", + /*description=*/ + "LLVM target triple used for object emission." + >, + Option< + /*C++ var name=*/"cpu", + /*CLI arg name=*/"cpu", + /*type=*/"std::string", + /*default=*/"\"generic\"", + /*description=*/ + "LLVM CPU model fed to `createTargetMachine`. Defaults to " + "`generic`, which emits baseline code for the triple." + >, + Option< + /*C++ var name=*/"features", + /*CLI arg name=*/"features", + /*type=*/"std::string", + /*default=*/"\"\"", + /*description=*/ + "Comma-separated subtarget features fed to `createTargetMachine` " + "(each token must be `+feat` or `-feat`, e.g. `+crc,+aes,+sha2`). " + "Empty (default) lets the CPU choose its own feature set." + >, + Option< + /*C++ var name=*/"address", + /*CLI arg name=*/"address", + /*type=*/"std::string", + /*default=*/"\"\"", + /*description=*/ + "Executor's TCP `host:port` address for remote dispatch." + > + ]; +} + def ConvertRemoteToLLVMPass : Pass<"convert-remote-to-llvm", "mlir::ModuleOp"> { let summary = "Lower the `remote` dialect to direct calls into the Catalyst remote runtime."; let description = [{ diff --git a/mlir/lib/Remote/Transforms/CMakeLists.txt b/mlir/lib/Remote/Transforms/CMakeLists.txt index 3476c40708..235863da45 100644 --- a/mlir/lib/Remote/Transforms/CMakeLists.txt +++ b/mlir/lib/Remote/Transforms/CMakeLists.txt @@ -1,14 +1,22 @@ set(LIBRARY_NAME remote-transforms) +set(LLVM_LINK_COMPONENTS + AllTargetsAsmParsers + AllTargetsCodeGens +) + file(GLOB SRC + CrossCompileRemoteKernels.cpp RemoteToLLVM.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + ${translation_libs} MLIRRemote MLIRCatalyst ) diff --git a/mlir/lib/Remote/Transforms/CrossCompileRemoteKernels.cpp b/mlir/lib/Remote/Transforms/CrossCompileRemoteKernels.cpp new file mode 100644 index 0000000000..cd2c8574c2 --- /dev/null +++ b/mlir/lib/Remote/Transforms/CrossCompileRemoteKernels.cpp @@ -0,0 +1,314 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/TargetParser/Triple.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Export.h" + +#include "Catalyst/IR/CatalystOps.h" +#include "Driver/DefaultPipelines/DefaultPipelines.h" + +#include "Remote/IR/RemoteOps.h" +#include "Remote/Transforms/Passes.h" + +using namespace mlir; + +namespace catalyst { +namespace remote { + +#define GEN_PASS_DEF_CROSSCOMPILEREMOTEKERNELSPASS +#include "Remote/Transforms/Passes.h.inc" + +namespace { + +/// Sanitize the qnode so it can be compiled into a standalone module: +/// 1. Drop `llvm.linkage` and `qnode`. +/// 2. Promote to public visibility. +/// 3. Add `llvm.emit_c_interface` so it's callable from C as +/// `_mlir_ciface_` (and `_catalyst_pyface_` is provided by +/// downstream wrapping). +void sanitizeQNode(func::FuncOp fn) +{ + OpBuilder builder(fn.getContext()); + fn.setVisibility(SymbolTable::Visibility::Public); + fn->setAttr("llvm.emit_c_interface", builder.getUnitAttr()); + fn->setAttr("catalyst.remote_kernel", builder.getUnitAttr()); + fn->removeAttr("llvm.linkage"); + fn->removeAttr("qnode"); +} + +struct CrossCompileRemoteKernelsPass + : impl::CrossCompileRemoteKernelsPassBase { + using CrossCompileRemoteKernelsPassBase::CrossCompileRemoteKernelsPassBase; + + void runOnOperation() final + { + ModuleOp host = getOperation(); + + SmallVector qnodes = + llvm::filter_to_vector(host.getOps(), [&](func::FuncOp fn) { + return fn->hasAttr("qnode") && !fn->hasAttr("catalyst.remote_kernel"); + }); + + SmallVector calls; + host.walk([&](catalyst::CustomCallOp call) { + if (call.getCallTargetName() == "remote_call") { + calls.push_back(call); + } + }); + + if (qnodes.empty() && calls.empty()) { + return; + } + + if (!qnodes.empty()) { + if (workspace.empty()) { + host.emitError("Missing `workspace` option for remote kernel cross-compilation"); + return signalPassFailure(); + } + + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + } + + injectRemoteOpenIntoSetup(host); + + for (auto qnode : qnodes) { + if (failed(compileQNode(host, qnode))) { + return signalPassFailure(); + } + } + + rewriteLegacyLibCalls(calls); + } + + /// Rewrite legacy `catalyst.custom_call fn("remote_call")` ops into + /// typed `remote.call` ops carrying the executor address. + void rewriteLegacyLibCalls(ArrayRef libCalls) + { + MLIRContext *ctx = &getContext(); + auto addressAttr = StringAttr::get(ctx, address); + for (catalyst::CustomCallOp call : libCalls) { + auto symAttr = call->getAttrOfType("catalyst.remote_symbol"); + if (!symAttr) { + call->emitOpError("legacy remote_call missing `catalyst.remote_symbol` attribute"); + return signalPassFailure(); + } + OpBuilder b(call); + IntegerAttr numInputAttr = nullptr; + if (auto n = call.getNumberOriginalArg()) { + numInputAttr = b.getI32IntegerAttr(*n); + } + auto remoteCall = + remote::CallOp::create(b, call.getLoc(), call.getResultTypes(), call.getOperands(), + /*address=*/addressAttr, /*symbol=*/symAttr, + /*num_input_args=*/numInputAttr); + call.replaceAllUsesWith(remoteCall.getResults()); + call.erase(); + } + } + + /// Insert `remote.open` into the host's `setup` function exactly once. + void injectRemoteOpenIntoSetup(ModuleOp host) + { + auto setupFn = host.lookupSymbol("setup"); + if (!setupFn || setupFn.getBody().empty()) { + return; + } + Block &setupBody = setupFn.getBody().front(); + Operation *terminator = setupBody.getTerminator(); + if (!terminator) { + return; + } + OpBuilder b(terminator); + remote::OpenOp::create(b, setupFn.getLoc(), + /*address=*/StringAttr::get(&getContext(), address)); + } + + /// Insert `remote.send_binary` into the host's `setup` function for the + /// kernel we just produced. + void injectRemoteSendBinaryIntoSetup(ModuleOp host, StringAttr addressAttr, StringAttr pathAttr) + { + auto setupFn = host.lookupSymbol("setup"); + if (!setupFn || setupFn.getBody().empty()) { + return; + } + Block &setupBody = setupFn.getBody().front(); + Operation *terminator = setupBody.getTerminator(); + if (!terminator) { + return; + } + OpBuilder b(terminator); + remote::SendBinaryOp::create(b, setupFn.getLoc(), + /*address=*/addressAttr, + /*binary_path=*/pathAttr); + } + + /// Clone the qnode into a fresh module, sanitize it, run the LLVM-dialect + /// lowering stage, then translate to LLVM IR. + std::unique_ptr loweringQNode(func::FuncOp qnode, llvm::LLVMContext &llvmCtx) + { + MLIRContext *ctx = &getContext(); + StringRef qnodeName = qnode.getName(); + + OpBuilder builder(ctx); + auto clone = ModuleOp::create(builder.getUnknownLoc(), qnodeName); + auto qnodeClone = cast(qnode->clone()); + clone.getBody()->push_back(qnodeClone); + sanitizeQNode(qnodeClone); + + PassManager nested(ctx); + std::string passList = llvm::join(catalyst::driver::getLLVMDialectLoweringStage(), ","); + if (failed(parsePassPipeline(passList, nested))) { + qnode.emitError("Failed to parse LLVM-dialect lowering pipeline"); + clone->erase(); + return nullptr; + } + if (failed(nested.run(clone))) { + qnode.emitError("Lowering failed on cloned qnode"); + clone->erase(); + return nullptr; + } + + std::unique_ptr llvmModule = + translateModuleToLLVMIR(clone, llvmCtx, /*name=*/qnodeName); + if (!llvmModule) { + qnode.emitError("Failed to translate lowered qnode to LLVM IR"); + clone->erase(); + return nullptr; + } + clone->erase(); + return llvmModule; + } + + /// Emit the LLVM module to `/.o`. + std::string emitObjectFile(std::unique_ptr &&llvmModule, StringRef qnodeName, + StringRef targetTriple, StringRef cpuModel, StringRef featureList) + { + llvm::Triple parsedTriple{targetTriple}; + std::string err; + const llvm::Target *llvmTarget = llvm::TargetRegistry::lookupTarget(parsedTriple, err); + if (!llvmTarget) { + llvm::errs() << "Target triple '" << targetTriple + << "' not registered in this LLVM build: " << err << "\n"; + return ""; + } + llvm::TargetOptions opt; + std::unique_ptr targetMachine(llvmTarget->createTargetMachine( + parsedTriple, cpuModel, featureList, opt, llvm::Reloc::Model::PIC_)); + if (!targetMachine) { + llvm::errs() << "Could not create TargetMachine for triple '" << targetTriple + << "' cpu='" << cpuModel << "' features='" << featureList << "'\n"; + return ""; + } + + targetMachine->setOptLevel(llvm::CodeGenOptLevel::Aggressive); + llvmModule->setDataLayout(targetMachine->createDataLayout()); + llvmModule->setTargetTriple(parsedTriple); + + llvm::SmallString<128> p(workspace); + llvm::sys::path::append(p, qnodeName.str() + ".o"); + std::string objPath = std::string(p.str()); + + std::error_code errCode; + llvm::raw_fd_ostream dest(objPath, errCode, llvm::sys::fs::OF_None); + if (errCode) { + llvm::errs() << "Cannot open " << objPath << " for writing: " << errCode.message() + << "\n"; + return ""; + } + llvm::legacy::PassManager codegenPM; + if (targetMachine->addPassesToEmitFile(codegenPM, dest, nullptr, + llvm::CodeGenFileType::ObjectFile)) { + llvm::errs() << "TargetMachine cannot emit an object file for the requested target" + << "\n"; + return ""; + } + codegenPM.run(*llvmModule); + dest.flush(); + + return objPath; + } + + /// Compile a qnode to a `.o`, then replace every host-side `func.call` to + /// it with a `remote.launch` op and inject the matching `remote.send_binary` + /// into setup. + LogicalResult compileQNode(ModuleOp host, func::FuncOp qnode) + { + MLIRContext *ctx = &getContext(); + StringRef qnodeName = qnode.getName(); + + llvm::LLVMContext llvmCtx; + auto llvmModule = loweringQNode(qnode, llvmCtx); + if (!llvmModule) { + return failure(); + } + + std::string objPath = + emitObjectFile(std::move(llvmModule), qnodeName, target, cpu, features); + if (objPath.empty()) { + return failure(); + } + + SmallVector callsToReplace; + if (auto uses = SymbolTable::getSymbolUses(qnode.getNameAttr(), host)) { + for (const SymbolTable::SymbolUse &use : *uses) { + if (auto call = dyn_cast(use.getUser())) { + callsToReplace.push_back(call); + } + } + } + + auto pathAttr = StringAttr::get(ctx, objPath); + auto addressAttr = StringAttr::get(ctx, address); + auto calleeAttr = StringAttr::get(ctx, qnodeName); + + for (func::CallOp call : callsToReplace) { + OpBuilder builder(call); + auto launch = remote::LaunchOp::create(builder, call.getLoc(), call.getResultTypes(), + call.getOperands(), addressAttr, calleeAttr); + call.replaceAllUsesWith(launch.getResults()); + call.erase(); + } + + injectRemoteSendBinaryIntoSetup(host, addressAttr, pathAttr); + + qnode.erase(); + return success(); + } +}; + +} // namespace + +} // namespace remote +} // namespace catalyst