From 04ecd3eb7bb4db5509b4c653b5b40de804154da4 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Sun, 17 May 2026 03:40:19 -0400 Subject: [PATCH 1/6] Add kernel compilation pass for ORC JIT remote execution --- frontend/catalyst/compiler.py | 4 + frontend/catalyst/utils/CMakeLists.txt | 5 + frontend/catalyst/utils/wrapper.cpp | 19 + mlir/Makefile | 7 +- mlir/include/Catalyst/Transforms/Passes.td | 62 ++ .../DefaultPipelines/DefaultPipelines.h | 12 +- mlir/lib/Catalyst/Transforms/CMakeLists.txt | 8 + .../Transforms/CrossCompileRemoteKernels.cpp | 355 +++++++++ .../Catalyst/Transforms/catalyst_to_llvm.cpp | 411 ++++++++++ runtime/Makefile | 6 + runtime/include/Exception.hpp | 2 +- runtime/include/RemoteCAPI.h | 47 ++ runtime/include/RuntimeCAPI.h | 3 + runtime/lib/CMakeLists.txt | 4 + runtime/lib/capi/ExecutionContext.hpp | 10 + runtime/lib/capi/RuntimeCAPI.cpp | 2 + runtime/lib/remote/CMakeLists.txt | 57 ++ runtime/lib/remote/RemoteRuntime.cpp | 274 +++++++ runtime/lib/remote/RemoteSession.cpp | 747 ++++++++++++++++++ runtime/lib/remote/RemoteSession.hpp | 55 ++ 20 files changed, 2080 insertions(+), 10 deletions(-) create mode 100644 mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp create mode 100644 runtime/include/RemoteCAPI.h create mode 100644 runtime/lib/remote/CMakeLists.txt create mode 100644 runtime/lib/remote/RemoteRuntime.cpp create mode 100644 runtime/lib/remote/RemoteSession.cpp create mode 100644 runtime/lib/remote/RemoteSession.hpp diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 3c8c53ca98..cfc8e52759 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -176,6 +176,10 @@ def get_default_flags(options): # "-lrt_decoder", TODO: Re-enable when stringop-overflow warning on arm64 is resolved ] + rt_remote_so = "librt_remote" + file_extension + if os.path.isfile(os.path.join(rt_lib_path, rt_remote_so)): + default_flags.append("-lrt_remote") + # If OQD runtime capi is built, link to it as well # TODO: This is not ideal and should be replaced when the compiler is device aware if os.path.isfile(os.path.join(rt_lib_path, "librt_OQD_capi" + file_extension)): diff --git a/frontend/catalyst/utils/CMakeLists.txt b/frontend/catalyst/utils/CMakeLists.txt index 51ec6d2350..5566d4ebaf 100644 --- a/frontend/catalyst/utils/CMakeLists.txt +++ b/frontend/catalyst/utils/CMakeLists.txt @@ -38,6 +38,11 @@ nanobind_add_module(wrapper STABLE_ABI ${WRAPPER_SRC_FILES}) # Add the NumPy include directory to the library's include paths target_include_directories(wrapper PRIVATE ${Python_NumPy_INCLUDE_DIRS}) +# Catalyst runtime headers (for `Catalyst::Runtime::RuntimeException`). +target_include_directories(wrapper PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/../../../runtime/include +) + # Use suffix ".so" rather than ".abi3.so" for library file using Stable ABI # This is necessary for compatibility with setuptools build extensions set_target_properties(wrapper PROPERTIES SUFFIX ".so") diff --git a/frontend/catalyst/utils/wrapper.cpp b/frontend/catalyst/utils/wrapper.cpp index f9e5c29b8f..1f0ae3ba9c 100644 --- a/frontend/catalyst/utils/wrapper.cpp +++ b/frontend/catalyst/utils/wrapper.cpp @@ -22,6 +22,8 @@ #include "numpy/ndarrayobject.h" +#include "Exception.hpp" + namespace nb = nanobind; struct memref_beginning_t { @@ -233,6 +235,23 @@ nb::list wrap(nb::object func, nb::tuple py_args, nb::object result_desc, nb::ob NB_MODULE(wrapper, m) { m.doc() = "wrapper module"; + + nb::register_exception_translator([](const std::exception_ptr &p, void * /*payload*/) { + try { + std::rethrow_exception(p); + } + catch (const Catalyst::Runtime::RuntimeException &e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + } + catch (const std::exception &e) { + PyErr_SetString(PyExc_RuntimeError, e.what()); + } + catch (...) { + PyErr_SetString(PyExc_RuntimeError, + "unknown C++ exception caught by wrapper translator"); + } + }); + // We have to annotate all the arguments to `wrap` to allow `result_desc` to be None // See https://nanobind.readthedocs.io/en/latest/functions.html#none-arguments m.def("wrap", &wrap, "A wrapper function.", nb::arg("func"), nb::arg("py_args"), diff --git a/mlir/Makefile b/mlir/Makefile index 3f19ff7d9b..becabca351 100644 --- a/mlir/Makefile +++ b/mlir/Makefile @@ -36,8 +36,9 @@ USE_SANITIZER_NAMES="" USE_SANITIZER_FLAGS="" endif -LLVM_PROJECTS ?= mlir -LLVM_TARGETS ?= check-mlir llvm-symbolizer +LLVM_PROJECTS ?= mlir;lld +LLVM_TARGETS ?= check-mlir llvm-symbolizer lld llvm-strip +LLVM_TARGETS_TO_BUILD ?= host;AArch64;X86 .PHONY: help help: @@ -75,7 +76,7 @@ llvm: cmake -G Ninja -S llvm-project/llvm -B $(LLVM_BUILD_DIR) \ -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ -DLLVM_BUILD_EXAMPLES=OFF \ - -DLLVM_TARGETS_TO_BUILD="host" \ + -DLLVM_TARGETS_TO_BUILD="$(LLVM_TARGETS_TO_BUILD)" \ -DLLVM_ENABLE_PROJECTS="$(LLVM_PROJECTS)" \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index c5f87f0d7c..d87b1c4452 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -143,6 +143,68 @@ def ApplyTransformSequencePass : Pass<"apply-transform-sequence"> { let summary = "Apply the passes scheduled with the transform dialect."; } +def CrossCompileRemoteKernelsPass : Pass<"cross-compile-remote-kernels", "mlir::ModuleOp"> { + let summary = "Cross-compile each `qnode` to a `.o` and stash the path on the host call sites."; + let description = [{ + For every `func.func` in the host module carrying the `qnode` attribute, this pass: + 1. Clones the func into a standalone module. + 2. Sanitizes it for lowering to LLVM IR and emits the `.o` to `/.o`. + 3. Marks every `func.call` targeting the qnode with `catalyst.remote_kernel_path` + holding the absolute path of the produced `.o`. + }]; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "mlir::func::FuncDialect", + ]; + + let options = [ + Option< + /*C++ var name=*/"workspace", + /*CLI arg name=*/"workspace", + /*type=*/"std::string", + /*default=*/"\"\"", + /*description=*/ + "Filesystem directory to write cross-compiled `.o` files into." + >, + Option< + /*C++ var name=*/"target", + /*CLI arg name=*/"target", + /*type=*/"std::string", + /*default=*/"\"x86_64\"", + /*description=*/ + "LLVM target triple used for object emission." + >, + Option< + /*C++ var name=*/"cpu", + /*CLI arg name=*/"cpu", + /*type=*/"std::string", + /*default=*/"\"generic\"", + /*description=*/ + "LLVM CPU model fed to `createTargetMachine`. Defaults to " + "`generic`, which emits baseline code for the triple." + >, + Option< + /*C++ var name=*/"features", + /*CLI arg name=*/"features", + /*type=*/"std::string", + /*default=*/"\"\"", + /*description=*/ + "Comma-separated subtarget features fed to `createTargetMachine` " + "(each token must be `+feat` or `-feat`, e.g. `+crc,+aes,+sha2`). " + "Empty (default) lets the CPU choose its own feature set." + >, + Option< + /*C++ var name=*/"address", + /*CLI arg name=*/"address", + /*type=*/"std::string", + /*default=*/"\"\"", + /*description=*/ + "Executor's TCP `host:port` address for remote dispatch." + > + ]; +} + def InlineNestedModulePass : Pass<"inline-nested-module"> { let summary = "Inline nested modules with qnode attribute."; diff --git a/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h b/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h index 33aec77e6d..088af5b3c8 100644 --- a/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h +++ b/mlir/include/Driver/DefaultPipelines/DefaultPipelines.h @@ -163,7 +163,7 @@ const PipelineList pipelineList{ "register-inactive-callback"}}}; // clang-format on -PipelineNames getPipelineNames() +inline PipelineNames getPipelineNames() { static std::vector names = std::accumulate(driver::pipelineList.begin(), driver::pipelineList.end(), @@ -174,7 +174,7 @@ PipelineNames getPipelineNames() return names; } -PassNames getQuantumCompilationStage(bool disableAssertion = true) +inline PassNames getQuantumCompilationStage(bool disableAssertion = true) { PassNames ret; std::copy_if(pipelineList[0].passNames.begin(), pipelineList[0].passNames.end(), @@ -184,11 +184,11 @@ PassNames getQuantumCompilationStage(bool disableAssertion = true) return ret; } -PassNames getHLOLoweringStage() { return pipelineList[1].passNames; } +inline PassNames getHLOLoweringStage() { return pipelineList[1].passNames; } -PassNames getGradientLoweringStage() { return pipelineList[2].passNames; } +inline PassNames getGradientLoweringStage() { return pipelineList[2].passNames; } -PassNames getBufferizationStage(bool asyncQNodes = false) +inline PassNames getBufferizationStage(bool asyncQNodes = false) { const std::string bufferizationOptions = std::string("{bufferize-function-boundaries ") + "allow-return-allocs-from-loops " + @@ -206,7 +206,7 @@ PassNames getBufferizationStage(bool asyncQNodes = false) return ret; } -PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false) +inline PassNames getLLVMDialectLoweringStage(bool asyncQNodes = false) { PassNames ret; std::copy_if(pipelineList[4].passNames.begin(), pipelineList[4].passNames.end(), diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index bb794d3b4c..a7a5211dcc 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ file(GLOB SRC BufferDeallocation.cpp BufferizableOpInterfaceImpl.cpp catalyst_to_llvm.cpp + CrossCompileRemoteKernels.cpp DetectQNodes.cpp DetensorizeFunctionBoundaryPass.cpp DetensorizeSCFPass.cpp @@ -29,11 +30,18 @@ file(GLOB SRC TBAATagsPass.cpp ) +set(LLVM_LINK_COMPONENTS + AllTargetsAsmParsers + AllTargetsCodeGens +) + get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) set(LIBS ${dialect_libs} ${conversion_libs} + ${translation_libs} catalyst-analysis ) diff --git a/mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp b/mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp new file mode 100644 index 0000000000..9eab453a01 --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp @@ -0,0 +1,355 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/TargetParser/Triple.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Export.h" + +#include "Catalyst/IR/CatalystOps.h" +#include "Driver/DefaultPipelines/DefaultPipelines.h" + +using namespace mlir; + +namespace catalyst { + +#define GEN_PASS_DECL_CROSSCOMPILEREMOTEKERNELSPASS +#define GEN_PASS_DEF_CROSSCOMPILEREMOTEKERNELSPASS +#include "Catalyst/Transforms/Passes.h.inc" + +namespace { + +// Sanitize the qnode for compiling the qnode into a standalone module. + +/** + * @brief Sanitize the qnode for compiling the qnode into a standalone module. + * The conversion takes place in the following steps: + * 1. Strip `llvm.linkage` + * 2. Set it to public visibility. + * 3. Add `llvm.emit_c_interface` to make the qnode callable from C. + * We then further wrap it in a `builtin.module` and emit the `.o` file. + * + * @example + * Before sanitization: + * ```mlir + * func.func @qnode_0(%arg0: memref<1xi64>) -> memref + * attributes {llvm.linkage = #llvm.linkage, qnode} { + * return %0 : memref + * } + * ``` + * After sanitization: + * ```mlir + * func.func @qnode_0(%arg0: memref<1xi64>) -> memref + * attributes {llvm.emit_c_interface, catalyst.remote_kernel} { + * return %0 : memref + * } + * ``` + * @param fn The qnode to sanitize. + */ +void sanitizeQNode(func::FuncOp fn) +{ + OpBuilder builder(fn.getContext()); + fn.setVisibility(SymbolTable::Visibility::Public); + fn->setAttr("llvm.emit_c_interface", builder.getUnitAttr()); + fn->setAttr("catalyst.remote_kernel", builder.getUnitAttr()); + fn->removeAttr("llvm.linkage"); + fn->removeAttr("qnode"); +} + +struct CrossCompileRemoteKernelsPass + : impl::CrossCompileRemoteKernelsPassBase { + using CrossCompileRemoteKernelsPassBase::CrossCompileRemoteKernelsPassBase; + + void runOnOperation() final + { + ModuleOp host = getOperation(); + + SmallVector qnodes = + llvm::filter_to_vector(host.getOps(), [&](func::FuncOp fn) { + return fn->hasAttr("qnode") && !fn->hasAttr("catalyst.remote_kernel"); + }); + + if (qnodes.empty()) { + return; + } + + if (workspace.empty()) { + host.emitError("Missing `workspace` option for remote kernel cross-compilation"); + return signalPassFailure(); + } + + // For cross-compilation, we need to initialize the LLVM target registry. + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + + injectRemoteOpenIntoSetup(host); + + for (auto qnode : qnodes) { + if (failed(compileQNode(host, qnode))) { + return signalPassFailure(); + } + } + + attachAddressToPluginCalls(host); + } + + // For each `catalyst.custom_call fn("remote_lib_call")`, + // attach the executor's `catalyst.remote_address` for remote library calls. + void attachAddressToPluginCalls(ModuleOp host) + { + auto addressAttr = StringAttr::get(&getContext(), address); + host.walk([&](catalyst::CustomCallOp call) { + if (call.getCallTargetName() == "remote_lib_call") { + call->setAttr("catalyst.remote_address", addressAttr); + } + }); + } + + // Insert `__catalyst__remote__open(addr)` into the host's `setup` function exactly once. + void injectRemoteOpenIntoSetup(ModuleOp host) + { + auto setupFn = host.lookupSymbol("setup"); + if (!setupFn || setupFn.getBody().empty()) { + return; + } + Block &setupBody = setupFn.getBody().front(); + Operation *terminator = setupBody.getTerminator(); + if (!terminator) { + return; + } + OpBuilder b(terminator); + Location loc = setupFn.getLoc(); + auto openOp = catalyst::CustomCallOp::create( + b, loc, /*resultTypes=*/TypeRange{}, /*inputs=*/ValueRange{}, + /*call_target_name=*/"remote_open", /*number_original_arg=*/nullptr); + openOp->setAttr("catalyst.remote_address", StringAttr::get(&getContext(), address)); + } + + // Insert `__catalyst__remote__send_binary(addr, path)` into the host's `setup` function. + void injectRemoteSendBinaryIntoSetup(ModuleOp host, StringAttr addressAttr, StringAttr pathAttr, + StringAttr calleeAttr) + { + auto setupFn = host.lookupSymbol("setup"); + if (!setupFn || setupFn.getBody().empty()) { + return; + } + Block &setupBody = setupFn.getBody().front(); + Operation *terminator = setupBody.getTerminator(); + if (!terminator) { + return; + } + OpBuilder b(terminator); + Location loc = setupFn.getLoc(); + auto sendOp = catalyst::CustomCallOp::create(b, loc, /*resultTypes=*/TypeRange{}, + /*inputs=*/ValueRange{}, + /*call_target_name=*/"remote_send_binary", + /*number_original_arg=*/nullptr); + sendOp->setAttr("catalyst.remote_address", addressAttr); + sendOp->setAttr("catalyst.remote_kernel_path", pathAttr); + sendOp->setAttr("catalyst.remote_kernel_callee", calleeAttr); + } + + /** + * @brief Lower the qnode to LLVM IR. + * This function takes the qnode and lowers it to LLVM IR for cross-compilation. + * + * @param qnode The qnode to lower. + * @param llvmCtx Context that owns the produced module; must outlive the returned Module. + * @return std::unique_ptr The LLVM module containing the lowered qnode. + */ + std::unique_ptr loweringQNode(func::FuncOp qnode, llvm::LLVMContext &llvmCtx) + { + MLIRContext *ctx = &getContext(); + StringRef qnodeName = qnode.getName(); + + // Step 1: Clone the qnode into a fresh standalone module. + OpBuilder builder(ctx); + auto clone = ModuleOp::create(builder.getUnknownLoc(), qnodeName); + auto qnodeClone = cast(qnode->clone()); + clone.getBody()->push_back(qnodeClone); + sanitizeQNode(qnodeClone); + + // Step 2: Run the LLVM-dialect lowering stage on the clone. + PassManager nested(ctx); + std::string passList = llvm::join(catalyst::driver::getLLVMDialectLoweringStage(), ","); + if (failed(parsePassPipeline(passList, nested))) { + qnode.emitError("Failed to parse LLVM-dialect lowering pipeline"); + clone->erase(); + return nullptr; + } + if (failed(nested.run(clone))) { + qnode.emitError("Lowering failed on cloned qnode"); + clone->erase(); + return nullptr; + } + + // Step 3: Translate the lowered clone to LLVM IR. + std::unique_ptr llvmModule = + translateModuleToLLVMIR(clone, llvmCtx, /*name=*/qnodeName); + if (!llvmModule) { + qnode.emitError("Failed to translate lowered qnode to LLVM IR"); + clone->erase(); + return nullptr; + } + clone->erase(); + return llvmModule; + } + + /** + * @brief Emit the LLVM module to an object file. + * This function takes the LLVM module and emits it to an object file. + * + * @param llvmModule The LLVM module to emit. + * @param qnodeName The name of the qnode. + * @param target The target triple to emit the object file for. + * @param cpu The CPU model. + * @param features Comma-separated `+feat`/`-feat` tokens, or empty to let + * the CPU pick its own feature set. + * @return std::string The path to the emitted object file. + */ + std::string emitObjectFile(std::unique_ptr &&llvmModule, StringRef qnodeName, + StringRef target, StringRef cpu, StringRef features) + { + std::string objPath; + + // Build a TargetMachine + object emission. + llvm::Triple parsedTriple{target}; + std::string err; + const llvm::Target *llvmTarget = llvm::TargetRegistry::lookupTarget(parsedTriple, err); + if (!llvmTarget) { + llvm::errs() << "Target triple '" << target + << "' not registered in this LLVM build: " << err << "\n"; + return ""; + } + llvm::TargetOptions opt; + std::unique_ptr targetMachine(llvmTarget->createTargetMachine( + parsedTriple, cpu, features, opt, llvm::Reloc::Model::PIC_)); + if (!targetMachine) { + llvm::errs() << "Could not create TargetMachine for triple '" << target + << "' cpu='" << cpu << "' features='" << features << "'\n"; + return ""; + } + + targetMachine->setOptLevel(llvm::CodeGenOptLevel::Aggressive); // -O3 + llvmModule->setDataLayout(targetMachine->createDataLayout()); + llvmModule->setTargetTriple(parsedTriple); + + llvm::SmallString<128> p(workspace); + llvm::sys::path::append(p, qnodeName.str() + ".o"); + objPath = std::string(p.str()); + + std::error_code errCode; + llvm::raw_fd_ostream dest(objPath, errCode, llvm::sys::fs::OF_None); + if (errCode) { + llvm::errs() << "Cannot open " << objPath << " for writing: " << errCode.message() + << "\n"; + return ""; + } + llvm::legacy::PassManager codegenPM; + if (targetMachine->addPassesToEmitFile(codegenPM, dest, nullptr, + llvm::CodeGenFileType::ObjectFile)) { + llvm::errs() << "TargetMachine cannot emit an object file for the requested target" + << "\n"; + return ""; + } + codegenPM.run(*llvmModule); + dest.flush(); + + return objPath; + } + + /** + * @brief Compile the qnode as a standalone module, emit the object file, and replace the + * host-side calls with custom calls. + * + * @param host The host module containing the qnode. + * @param qnode The qnode to compile. + * @return LogicalResult Success if the qnode is compiled successfully, failure otherwise. + */ + LogicalResult compileQNode(ModuleOp host, func::FuncOp qnode) + { + MLIRContext *ctx = &getContext(); + StringRef qnodeName = qnode.getName(); + + llvm::LLVMContext llvmCtx; + auto llvmModule = loweringQNode(qnode, llvmCtx); + if (!llvmModule) { + return failure(); + } + + // Build a TargetMachine + object emission. + std::string objPath = + emitObjectFile(std::move(llvmModule), qnodeName, target, cpu, features); + if (objPath.empty()) { + return failure(); + } + + // Replace every host-side `func.call @(...)` with a + // `catalyst.custom_call fn("remote_call")(...)` + SmallVector callsToReplace; + if (auto uses = SymbolTable::getSymbolUses(qnode.getNameAttr(), host)) { + for (const SymbolTable::SymbolUse &use : *uses) { + if (auto call = dyn_cast(use.getUser())) { + callsToReplace.push_back(call); + } + } + } + + auto pathAttr = StringAttr::get(ctx, objPath); + auto addressAttr = StringAttr::get(ctx, address); + auto calleeAttr = StringAttr::get(ctx, qnodeName); + for (func::CallOp call : callsToReplace) { + OpBuilder builder(call); + + auto custom = catalyst::CustomCallOp::create(builder, call.getLoc(), + /*resultTypes=*/call.getResultTypes(), + call.getOperands(), + /*call_target_name=*/"remote_call", + /*number_original_arg=*/nullptr); + custom->setAttr("catalyst.remote_kernel_path", pathAttr); + custom->setAttr("catalyst.remote_address", addressAttr); + custom->setAttr("catalyst.remote_kernel_callee", calleeAttr); + call.replaceAllUsesWith(custom.getResults()); + call.erase(); + } + + // Inject binary send into setup function. + injectRemoteSendBinaryIntoSetup(host, addressAttr, pathAttr, calleeAttr); + + qnode.erase(); + return success(); + } +}; + +} // namespace + +} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp index 783a8a1e04..626b5f4a91 100644 --- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp +++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp @@ -53,6 +53,81 @@ Value getGlobalString(Location loc, OpBuilder &rewriter, StringRef key, StringRe ArrayRef{0, 0}, LLVM::GEPNoWrapFlags::inbounds); } +// Get or create a `internal constant !llvm.array` global. +// And return a `!llvm.ptr` to its first element. +// +// llvm.mlir.global internal constant @(dense<[v0, v1, ...]> : tensor) +// : !llvm.array +// +// Call site: +// +// %addr = llvm.mlir.addressof @ : !llvm.ptr +// %ptr = llvm.getelementptr inbounds %addr[0, 0] +// : (!llvm.ptr) -> !llvm.ptr, !llvm.array +// +Value getGlobalI64Array(Location loc, OpBuilder &rewriter, StringRef key, ArrayRef values, + ModuleOp mod) +{ + Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + // Skip and return a null pointer if the array is empty. + if (values.empty()) { + return LLVM::ZeroOp::create(rewriter, loc, ptrTy); + } + Type i64Ty = rewriter.getI64Type(); + auto arrTy = LLVM::LLVMArrayType::get(i64Ty, values.size()); + LLVM::GlobalOp glb = mod.lookupSymbol(key); + // Create a new global if it doesn't exist. + if (!glb) { + auto tensorTy = RankedTensorType::get({static_cast(values.size())}, i64Ty); + auto valuesAttr = DenseElementsAttr::get(tensorTy, values); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(mod.getBody()); + glb = LLVM::GlobalOp::create(rewriter, loc, arrTy, /*isConstant=*/true, + LLVM::Linkage::Internal, key, valuesAttr); + } + return LLVM::GEPOp::create(rewriter, loc, ptrTy, arrTy, + LLVM::AddressOfOp::create(rewriter, loc, glb), + ArrayRef{0, 0}, LLVM::GEPNoWrapFlags::inbounds); +} + +// Allocate a stack buffer of `!llvm.array`. +// +// Outputs: +// +// %slot = llvm.alloca ... : !llvm.array +// %a0 = llvm.undef : !llvm.array +// %a1 = llvm.insertvalue %ptr0, %a0[0] : !llvm.array +// %a2 = llvm.insertvalue %ptr1, %a1[1] : !llvm.array +// ... +// llvm.store %aN, %slot : !llvm.array, !llvm.ptr +// +Value buildStackPtrArray(Location loc, RewriterBase &rewriter, ArrayRef ptrs) +{ + Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext()); + if (ptrs.empty()) { + return LLVM::ZeroOp::create(rewriter, loc, ptrTy); + } + auto arrTy = LLVM::LLVMArrayType::get(ptrTy, ptrs.size()); + Value alloca = getStaticAlloca(loc, rewriter, arrTy, 1); + Value arr = LLVM::UndefOp::create(rewriter, loc, arrTy); + for (auto [i, p] : llvm::enumerate(ptrs)) { + arr = LLVM::InsertValueOp::create(rewriter, loc, arr, p, SmallVector{(int64_t)i}); + } + LLVM::StoreOp::create(rewriter, loc, arr, alloca); + return alloca; +} + +// Calculate the byte size of one memref element. +// For complex, the size is 2 * sizeof(T). 1 for real, 1 for imaginary. +int64_t memrefElemSizeBytes(MemRefType ty) +{ + Type elem = ty.getElementType(); + if (auto cplx = dyn_cast(elem)) { + return 2 * ((cplx.getElementType().getIntOrFloatBitWidth() + 7) / 8); + } + return (elem.getIntOrFloatBitWidth() + 7) / 8; +} + enum NumericType : int8_t { index = 0, i1, @@ -324,6 +399,12 @@ struct CustomCallOpPattern : public OpConversionPattern { LogicalResult matchAndRewrite(CustomCallOp op, CustomCallOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Remote-dispatch custom_calls are lowered by their own dedicated patterns below. + StringRef name = op.getCallTargetName(); + if (name == "remote_call" || name == "remote_lib_call" || name == "remote_open" || + name == "remote_send_binary") { + return failure(); + } MLIRContext *ctx = op.getContext(); Location loc = op.getLoc(); // Create function @@ -435,6 +516,332 @@ struct CustomCallOpPattern : public OpConversionPattern { } }; +// Rewrite `catalyst.custom_call fn("remote_call") -> ...` +// to three runtime calls: +// +// __catalyst__remote__open(addr) +// __catalyst__remote__send_binary(addr,p) +// __catalyst__remote__launch(addr, "_catalyst_pyface_", +// num_in, in_descs, in_ranks, in_sizes, +// num_out, out_descs, out_ranks, out_sizes); +// +struct RemoteCustomCallOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CustomCallOp op, CustomCallOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + if (op.getCallTargetName() != "remote_call") { + return failure(); + } + + auto addrAttr = op->getAttrOfType("catalyst.remote_address"); + auto calleeAttr = op->getAttrOfType("catalyst.remote_kernel_callee"); + if (!addrAttr) { + llvm::errs() << "remote_call custom_call is missing `catalyst.remote_address`\n"; + return failure(); + } + if (!calleeAttr) { + llvm::errs() << "remote_call custom_call is missing `catalyst.remote_kernel_callee`\n"; + return failure(); + } + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + Type ptrTy = LLVM::LLVMPointerType::get(ctx); + Type i64Ty = rewriter.getI64Type(); + Type voidTy = LLVM::LLVMVoidType::get(ctx); + ModuleOp mod = op->getParentOfType(); + + // Declare extern runtime entry points. + Type launchSig = LLVM::LLVMFunctionType::get( + voidTy, {ptrTy, ptrTy, i64Ty, ptrTy, ptrTy, ptrTy, i64Ty, ptrTy, ptrTy, ptrTy}); + LLVM::LLVMFuncOp launchFn = catalyst::ensureFunctionDeclaration( + rewriter, op, "__catalyst__remote__launch", launchSig); + + // We need the address string globally so the launch knows which session to dispatch to. + std::string callee = calleeAttr.getValue().str(); + Value addrPtr = getGlobalString(loc, rewriter, "remote_addr_" + callee, + addrAttr.getValue().str() + '\0', mod); + + // Get a global string for the symbol name "_catalyst_pyface_" + std::string symbolName = "_catalyst_pyface_" + callee; + std::string symbolKey = "remote_sym_" + callee; + Value symbolPtr = getGlobalString(loc, rewriter, symbolKey, symbolName + '\0', mod); + + // Spill input descriptor structs to stack allocas + SmallVector inputDescPtrs; + SmallVector inputRanks, inputElemSizes; + for (auto [origInput, llvmInput] : llvm::zip(op.getOperands(), adaptor.getOperands())) { + auto memrefTy = cast(origInput.getType()); + inputRanks.push_back(memrefTy.getRank()); + inputElemSizes.push_back(memrefElemSizeBytes(memrefTy)); + Value alloca = getStaticAlloca(loc, rewriter, llvmInput.getType(), 1); + LLVM::StoreOp::create(rewriter, loc, llvmInput, alloca); + inputDescPtrs.push_back(alloca); + } + + // Allocate stack buffers holding input/output descriptor pointers. + SmallVector outputDescPtrs; + SmallVector outputRanks, outputElemSizes; + for (Type resultTy : op.getResultTypes()) { + auto memrefTy = cast(resultTy); + outputRanks.push_back(memrefTy.getRank()); + outputElemSizes.push_back(memrefElemSizeBytes(memrefTy)); + Type llvmDescTy = getTypeConverter()->convertType(resultTy); + Value alloca = getStaticAlloca(loc, rewriter, llvmDescTy, 1); + outputDescPtrs.push_back(alloca); + } + + Value inputDescsArr = buildStackPtrArray(loc, rewriter, inputDescPtrs); + Value outputDescsArr = buildStackPtrArray(loc, rewriter, outputDescPtrs); + + // Get global arrays for ranks / elem-sizes. + Value inputRanksArr = + getGlobalI64Array(loc, rewriter, "remote_in_ranks_" + callee, inputRanks, mod); + Value inputSizesArr = + getGlobalI64Array(loc, rewriter, "remote_in_sizes_" + callee, inputElemSizes, mod); + Value outputRanksArr = + getGlobalI64Array(loc, rewriter, "remote_out_ranks_" + callee, outputRanks, mod); + Value outputSizesArr = + getGlobalI64Array(loc, rewriter, "remote_out_sizes_" + callee, outputElemSizes, mod); + + Value numInputs = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(inputDescPtrs.size())); + Value numOutputs = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(outputDescPtrs.size())); + + LLVM::CallOp::create(rewriter, loc, launchFn, + ValueRange{addrPtr, symbolPtr, numInputs, inputDescsArr, inputRanksArr, + inputSizesArr, numOutputs, outputDescsArr, outputRanksArr, + outputSizesArr}); + + // Load the runtime-filled output descriptors and replace the op with them. + SmallVector results; + for (auto [descPtr, resultTy] : llvm::zip(outputDescPtrs, op.getResultTypes())) { + Type llvmDescTy = getTypeConverter()->convertType(resultTy); + Value loaded = LLVM::LoadOp::create(rewriter, loc, llvmDescTy, descPtr); + results.push_back(loaded); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +// Rewrite the `catalyst.custom_call fn("remote_open")` op to `__catalyst__remote__open(addr)`. +struct RemoteOpenOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CustomCallOp op, CustomCallOpAdaptor, + ConversionPatternRewriter &rewriter) const override + { + if (op.getCallTargetName() != "remote_open") { + return failure(); + } + auto addrAttr = op->getAttrOfType("catalyst.remote_address"); + if (!addrAttr) { + return op->emitOpError("remote_open call is missing `catalyst.remote_address`"); + } + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + Type ptrTy = LLVM::LLVMPointerType::get(ctx); + Type i64Ty = rewriter.getI64Type(); + ModuleOp mod = op->getParentOfType(); + + Type openSig = LLVM::LLVMFunctionType::get(i64Ty, {ptrTy}); + LLVM::LLVMFuncOp openFn = catalyst::ensureFunctionDeclaration( + rewriter, op, "__catalyst__remote__open", openSig); + + Value addrPtr = getGlobalString(loc, rewriter, "remote_setup_addr", + addrAttr.getValue().str() + '\0', mod); + + LLVM::CallOp::create(rewriter, loc, openFn, ValueRange{addrPtr}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Rewrite the `catalyst.custom_call fn("remote_send_binary")` op to +// `__catalyst__remote__send_binary(addr, path)`. +struct RemoteSendBinaryOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CustomCallOp op, CustomCallOpAdaptor, + ConversionPatternRewriter &rewriter) const override + { + if (op.getCallTargetName() != "remote_send_binary") { + return failure(); + } + auto addrAttr = op->getAttrOfType("catalyst.remote_address"); + auto pathAttr = op->getAttrOfType("catalyst.remote_kernel_path"); + auto calleeAttr = op->getAttrOfType("catalyst.remote_kernel_callee"); + if (!addrAttr) { + return op->emitOpError("remote_send_binary call is missing `catalyst.remote_address`"); + } + if (!pathAttr) { + return op->emitOpError( + "remote_send_binary call is missing `catalyst.remote_kernel_path`"); + } + if (!calleeAttr) { + return op->emitOpError( + "remote_send_binary call is missing `catalyst.remote_kernel_callee`"); + } + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + Type ptrTy = LLVM::LLVMPointerType::get(ctx); + Type i32Ty = rewriter.getI32Type(); + Type i64Ty = rewriter.getI64Type(); + ModuleOp mod = op->getParentOfType(); + + Type sendBinSig = LLVM::LLVMFunctionType::get(i64Ty, {ptrTy, ptrTy, i32Ty}); + LLVM::LLVMFuncOp sendBinFn = catalyst::ensureFunctionDeclaration( + rewriter, op, "__catalyst__remote__send_binary", sendBinSig); + + std::string callee = calleeAttr.getValue().str(); + Value addrPtr = getGlobalString(loc, rewriter, "remote_addr_" + callee, + addrAttr.getValue().str() + '\0', mod); + Value pathPtr = getGlobalString(loc, rewriter, "remote_path_" + callee, + pathAttr.getValue().str() + '\0', mod); + + // TODO: Hardcoded format tag for now. (0 as object) + Value formatTag = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32IntegerAttr(0)); + + LLVM::CallOp::create(rewriter, loc, sendBinFn, ValueRange{addrPtr, pathPtr, formatTag}); + rewriter.eraseOp(op); + return success(); + } +}; + +// Rewrite the `catalyst.custom_call fn("remote_lib_call")` op to +// `__catalyst__remote__call_wrapper(addr, sym, args_buf, args_size, &out, &out_size)`. +struct RemoteLibCallOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CustomCallOp op, CustomCallOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override + { + if (op.getCallTargetName() != "remote_lib_call") { + return failure(); + } + + auto addrAttr = op->getAttrOfType("catalyst.remote_address"); + auto symAttr = op->getAttrOfType("catalyst.remote_lib_symbol"); + if (!addrAttr) { + return op->emitOpError("remote_lib_call is missing `catalyst.remote_address`"); + } + if (!symAttr) { + return op->emitOpError("remote_lib_call is missing `catalyst.remote_lib_symbol`"); + } + std::string sym = symAttr.getValue().str(); + + Location loc = op.getLoc(); + MLIRContext *ctx = rewriter.getContext(); + Type ptrTy = LLVM::LLVMPointerType::get(ctx); + Type i32Ty = rewriter.getI32Type(); + Type i64Ty = rewriter.getI64Type(); + Type voidTy = LLVM::LLVMVoidType::get(ctx); + Type i8Ty = rewriter.getI8Type(); + ModuleOp mod = op->getParentOfType(); + + // Declare extern runtime entry points. + Type callSig = + LLVM::LLVMFunctionType::get(i32Ty, {ptrTy, ptrTy, ptrTy, i64Ty, ptrTy, ptrTy}); + LLVM::LLVMFuncOp callFn = catalyst::ensureFunctionDeclaration( + rewriter, op, "__catalyst__remote__call_wrapper", callSig); + Type freeSig = LLVM::LLVMFunctionType::get(voidTy, {ptrTy}); + LLVM::LLVMFuncOp freeFn = catalyst::ensureFunctionDeclaration( + rewriter, op, "__catalyst__remote__free_result", freeSig); + + SmallVector offsets; + int64_t totalSize = 0; + for (Type ty : op.getOperandTypes()) { + int64_t n = primitiveByteSize(ty); + if (n < 0) { + return op->emitOpError("unsupported arg type for remote_lib_call: ") + << ty << " (supports int/float/index/complex only)"; + } + offsets.push_back(totalSize); + totalSize += n; + } + + Type bufTy = LLVM::LLVMArrayType::get(i8Ty, totalSize > 0 ? totalSize : 1); + + // Symbols + Value addrPtr = getGlobalString(loc, rewriter, "remote_lib_addr_" + sym, + addrAttr.getValue().str() + '\0', mod); + Value symPtr = getGlobalString(loc, rewriter, "remote_lib_sym_" + sym, sym + '\0', mod); + + // Alloca args buffer + store each arg. + Value argsBuf = getStaticAlloca(loc, rewriter, bufTy, 1); + for (auto [llvmVal, off] : llvm::zip(adaptor.getOperands(), offsets)) { + Value slot = LLVM::GEPOp::create(rewriter, loc, ptrTy, bufTy, argsBuf, + ArrayRef{0, static_cast(off)}, + LLVM::GEPNoWrapFlags::inbounds); + LLVM::StoreOp::create(rewriter, loc, llvmVal, slot); + } + Value argsSize = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(totalSize)); + + // Alloca result buffer + size. + Value outBufSlot = getStaticAlloca(loc, rewriter, ptrTy, 1); + Value outSizeSlot = getStaticAlloca(loc, rewriter, i64Ty, 1); + + // Call the runtime. + LLVM::CallOp::create( + rewriter, loc, callFn, + ValueRange{addrPtr, symPtr, argsBuf, argsSize, outBufSlot, outSizeSlot}); + + // Decode return value (if any). + SmallVector returns; + Value outBuf; + if (!op.getResultTypes().empty()) { + if (op.getResultTypes().size() != 1) { + return op->emitOpError("remote_lib_call supports at most one result"); + } + Type retTy = op.getResultTypes().front(); + if (primitiveByteSize(retTy) < 0) { + return op->emitOpError("unsupported return type for remote_lib_call: ") << retTy; + } + Type retLLVMTy = getTypeConverter()->convertType(retTy); + outBuf = LLVM::LoadOp::create(rewriter, loc, ptrTy, outBufSlot); + Value rv = LLVM::LoadOp::create(rewriter, loc, retLLVMTy, outBuf); + returns.push_back(rv); + } + else { + outBuf = LLVM::LoadOp::create(rewriter, loc, ptrTy, outBufSlot); + } + + // Release the runtime-allocated result buffer. + LLVM::CallOp::create(rewriter, loc, freeFn, ValueRange{outBuf}); + + rewriter.replaceOp(op, returns); + return success(); + } + + private: + // Supported scalar byte sizes. Returns -1 for unsupported types. + static int64_t primitiveByteSize(Type ty) + { + if (auto i = dyn_cast(ty)) { + return (i.getWidth() + 7) / 8; + } + if (auto f = dyn_cast(ty)) { + return (f.getWidth() + 7) / 8; + } + if (isa(ty)) { + return 8; + } + if (auto c = dyn_cast(ty)) { + int64_t inner = primitiveByteSize(c.getElementType()); + return inner < 0 ? -1 : 2 * inner; + } + return -1; + } +}; + struct DefineCallbackOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -587,6 +994,10 @@ struct CatalystConversionPass : impl::CatalystConversionPassBase(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); diff --git a/runtime/Makefile b/runtime/Makefile index 84c9ba328e..ec79fc9c6f 100644 --- a/runtime/Makefile +++ b/runtime/Makefile @@ -13,6 +13,7 @@ CODE_COVERAGE ?= OFF BUILD_TYPE ?= RelWithDebInfo ENABLE_OPENQASM ?= ON ENABLE_OQD ?= OFF +ENABLE_REMOTE ?= OFF ENABLE_ASAN ?= OFF STRICT_WARNINGS ?= ON LLVM_DIR ?= $(MK_DIR)/../mlir/llvm-project/ @@ -53,6 +54,10 @@ ifeq ($(ENABLE_OQD), ON) TEST_TARGETS += runner_tests_oqd endif +ifeq ($(ENABLE_REMOTE), ON) + BUILD_TARGETS += rt_remote +endif + .PHONY: help help: @echo "Please use \`make ' where is one of" @@ -78,6 +83,7 @@ configure: -DCMAKE_CXX_COMPILER_LAUNCHER=$(COMPILER_LAUNCHER) \ -DENABLE_OPENQASM=$(ENABLE_OPENQASM) \ -DENABLE_OQD=$(ENABLE_OQD) \ + -DENABLE_REMOTE=$(ENABLE_REMOTE) \ -DENABLE_CODE_COVERAGE=$(CODE_COVERAGE) \ -DPython_EXECUTABLE=$(PYTHON) \ -DENABLE_ADDRESS_SANITIZER=$(ENABLE_ASAN) \ diff --git a/runtime/include/Exception.hpp b/runtime/include/Exception.hpp index 2e785231e2..14c34280a3 100644 --- a/runtime/include/Exception.hpp +++ b/runtime/include/Exception.hpp @@ -51,7 +51,7 @@ namespace Catalyst::Runtime { * @brief This is the general exception thrown by Catalyst for runtime errors * that is derived from `std::exception`. */ -class RuntimeException : public std::exception { +class __attribute__((visibility("default"))) RuntimeException : public std::exception { private: const std::string err_msg; diff --git a/runtime/include/RemoteCAPI.h b/runtime/include/RemoteCAPI.h new file mode 100644 index 0000000000..d6a778fe39 --- /dev/null +++ b/runtime/include/RemoteCAPI.h @@ -0,0 +1,47 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifndef REMOTECAPI_H +#define REMOTECAPI_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// Remote Runtime API. This is expected to be called by the host program to establish a remote +// session. Including open a session, send the binary to the remote, launch the kernel on the remote +// and close the session. +int __catalyst__remote__open(const char *addr); +int __catalyst__remote__send_binary(const char *addr, const char *path, uint32_t format); +void __catalyst__remote__launch(const char *addr, const char *entry_symbol, size_t num_inputs, + void *const *input_descs, const size_t *input_ranks, + const size_t *input_elem_sizes, size_t num_outputs, + void *const *output_descs, const size_t *output_ranks, + const size_t *output_elem_sizes); +int __catalyst__remote__call_wrapper(const char *addr, const char *symbol, const char *args_buf, + size_t args_size, void **out_buf, size_t *out_size); +void __catalyst__remote__free_result(void *buf); + +int __catalyst__remote__close(); +const char *__catalyst__remote__last_error(); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/runtime/include/RuntimeCAPI.h b/runtime/include/RuntimeCAPI.h index cb2e637d5f..61c5016705 100644 --- a/runtime/include/RuntimeCAPI.h +++ b/runtime/include/RuntimeCAPI.h @@ -119,6 +119,9 @@ RESULT *__catalyst__mbqc__measure_in_basis(QUBIT *, uint32_t, double, int32_t); // Async runtime error void __catalyst__host__rt__unrecoverable_error(); +// Allocate a host buffer of `size` bytes and register it with the runtime's memory manager +void *__catalyst__rt__alloc_managed(size_t size); + #ifdef __cplusplus } // extern "C" #endif diff --git a/runtime/lib/CMakeLists.txt b/runtime/lib/CMakeLists.txt index a4cd496c33..d4abd8e0e8 100644 --- a/runtime/lib/CMakeLists.txt +++ b/runtime/lib/CMakeLists.txt @@ -50,3 +50,7 @@ add_subdirectory(QEC) if(ENABLE_OQD) add_subdirectory(OQDcapi) endif() + +if(ENABLE_REMOTE) +add_subdirectory(remote) +endif() diff --git a/runtime/lib/capi/ExecutionContext.hpp b/runtime/lib/capi/ExecutionContext.hpp index 5bb1f4d825..5216912740 100644 --- a/runtime/lib/capi/ExecutionContext.hpp +++ b/runtime/lib/capi/ExecutionContext.hpp @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -88,6 +89,15 @@ class SharedLibraryManager final { #endif _handler = dlopen(filename.c_str(), rtld_flags); + if (!_handler) { + // Fall back to the base name of the library if the full path is not found. + std::string stem = std::filesystem::path(filename).stem().string(); + for (const char *ext : {".so", ".dylib", ".dll"}) { + if ((_handler = dlopen((stem + ext).c_str(), rtld_flags))) { + break; + } + } + } RT_FAIL_IF(!_handler, dlerror()); } diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 4c26399516..f0ac427f10 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -169,6 +169,8 @@ void *_mlir_memref_to_llvm_alloc(size_t size) return ptr; } +void *__catalyst__rt__alloc_managed(size_t size) { return _mlir_memref_to_llvm_alloc(size); } + void *_mlir_memref_to_llvm_aligned_alloc(size_t alignment, size_t size) { void *ptr = aligned_alloc(alignment, size); diff --git a/runtime/lib/remote/CMakeLists.txt b/runtime/lib/remote/CMakeLists.txt new file mode 100644 index 0000000000..37cd497abc --- /dev/null +++ b/runtime/lib/remote/CMakeLists.txt @@ -0,0 +1,57 @@ +############################################### +# library catalyst_remote_session + rt_remote # +############################################### + +find_package(LLVM CONFIG REQUIRED) + +llvm_map_components_to_libnames(_remote_llvm_libs + core support orcjit jitlink object passes + ${LLVM_TARGETS_TO_BUILD} +) + +# catalyst_remote_session + +add_library(catalyst_remote_session STATIC RemoteSession.cpp) + +target_include_directories(catalyst_remote_session + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_include_directories(catalyst_remote_session SYSTEM PRIVATE + ${LLVM_INCLUDE_DIRS} +) + +target_link_libraries(catalyst_remote_session PRIVATE + ${_remote_llvm_libs} + pthread + ${CMAKE_DL_LIBS} +) + +target_compile_options(catalyst_remote_session PRIVATE -fno-rtti) + +set_property(TARGET catalyst_remote_session PROPERTY POSITION_INDEPENDENT_CODE ON) + +# rt_remote + +add_library(rt_remote SHARED RemoteRuntime.cpp) + +target_include_directories(rt_remote + PRIVATE + ${runtime_includes} + ${CMAKE_CURRENT_SOURCE_DIR} +) + +target_link_libraries(rt_remote PRIVATE + catalyst_remote_session + rt_capi + pthread + ${CMAKE_DL_LIBS} +) + +set_property(TARGET rt_remote PROPERTY POSITION_INDEPENDENT_CODE ON) + +if(NOT APPLE) + set_property(TARGET rt_remote APPEND PROPERTY BUILD_RPATH $ORIGIN) +else() + set_property(TARGET rt_remote APPEND PROPERTY BUILD_RPATH @loader_path) +endif() diff --git a/runtime/lib/remote/RemoteRuntime.cpp b/runtime/lib/remote/RemoteRuntime.cpp new file mode 100644 index 0000000000..a97aec551a --- /dev/null +++ b/runtime/lib/remote/RemoteRuntime.cpp @@ -0,0 +1,274 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "Exception.hpp" +#include "RemoteCAPI.h" +#include "RemoteSession.hpp" + +namespace { + +struct RemoteEntry { + catalyst::remote::RemoteSession *session = nullptr; // The remote session handle. + std::set loaded_paths; // The paths of the binaries that are going to be loaded + // into the remote session. + std::mutex mu; // The mutex to protect the loaded paths. +}; + +std::mutex g_map_mu; + +// Each address has its own RemoteEntry, so we can dispatch the object file to different remote +// sessions. +std::map> remote_sessions; + +thread_local std::string g_remote_runtime_error; +void set_remote_runtime_error(const char *msg) +{ + if (msg) { + g_remote_runtime_error = msg; + } + else { + g_remote_runtime_error = "(unknown)"; + } +} + +// DEBUG logs +bool remote_verbose() +{ + static const bool v = []() { + const char *e = std::getenv("CATALYST_REMOTE_VERBOSE"); + return e && *e && *e != '0'; + }(); + return v; +} + +// Look up or create the entry for `addr`. +RemoteEntry *find_or_create_entry(const char *addr, bool create_if_missing) +{ + if (!addr || !*addr) { + return nullptr; + } + std::lock_guard lock(g_map_mu); + auto it = remote_sessions.find(addr); + if (it != remote_sessions.end()) { + return it->second.get(); + } + if (!create_if_missing) { + return nullptr; + } + auto inserted = remote_sessions.emplace(std::string(addr), std::make_unique()); + return inserted.first->second.get(); +} + +} // namespace + +extern "C" { + +int __catalyst__remote__open(const char *addr) +{ + if (!addr || !*addr) { + set_remote_runtime_error("Empty address"); + return -1; + } + RemoteEntry *entry = find_or_create_entry(addr, /*create_if_missing=*/true); + std::lock_guard lock(entry->mu); + if (entry->session) { + return 0; // idempotent per addr + } + if (remote_verbose()) { + std::fprintf(stderr, "[remote] open(addr=%s)\n", addr); + } + entry->session = catalyst::remote::open(addr); + if (!entry->session) { + std::string msg = "Could not connect to catalyst-executor at "; + msg += addr; + msg += ": "; + msg += catalyst::remote::last_error(); + set_remote_runtime_error(msg.c_str()); + std::lock_guard mapLock(g_map_mu); + remote_sessions.erase(addr); + return -1; + } + if (remote_verbose()) { + std::fprintf(stderr, "[remote] open(%s) OK\n", addr); + } + return 0; +} + +int __catalyst__remote__send_binary(const char *addr, const char *path, uint32_t format) +{ + RemoteEntry *entry = find_or_create_entry(addr, /*create_if_missing=*/false); + if (!entry) { + set_remote_runtime_error("No session found, call __catalyst__remote__open first."); + return -1; + } + std::lock_guard lock(entry->mu); + if (!entry->session) { + std::string msg = "__catalyst__remote__send_binary("; + msg += addr; + msg += "): session is closed."; + set_remote_runtime_error(msg.c_str()); + return -1; + } + if (!path || !*path) { + return 0; + } + std::string key(path); + if (!entry->loaded_paths.insert(key).second) { + return 0; + } + if (remote_verbose()) { + std::fprintf(stderr, "[remote] send_binary(addr=%s, path=%s, format=%u)\n", addr, path, + format); + } + + int rc = 0; + switch (format) { + case 0: + rc = catalyst::remote::load_object_path(entry->session, path); + break; + case 1: + rc = catalyst::remote::load_asset_path(entry->session, path); + break; + default: + std::string msg = "unknown binary format tag "; + msg += std::to_string(format); + set_remote_runtime_error(msg.c_str()); + rc = -1; + } + + if (rc != 0) { + set_remote_runtime_error(catalyst::remote::last_error()); + entry->loaded_paths.erase(key); + return -1; + } + return 0; +} + +/** + * @brief Generic ORC wrapper-function call by symbol name. Returns 0 on success, -1 on error. + * + * @param addr The address of the remote session. + * @param symbol The symbol of the function to call. + * @param args_buf The buffer of the arguments. + * @param args_size The size of the arguments. + * @param out_buf The buffer of the result. + * @param out_size The size of the result. + * @return int 0 on success, -1 on error. + */ +int __catalyst__remote__call_wrapper(const char *addr, const char *symbol, const char *args_buf, + size_t args_size, void **out_buf, size_t *out_size) +{ + if (out_buf) { + *out_buf = nullptr; + } + if (out_size) { + *out_size = 0; + } + RemoteEntry *entry = find_or_create_entry(addr, /*create_if_missing=*/false); + if (!entry) { + set_remote_runtime_error("No session found, call __catalyst__remote__open first."); + return -1; + } + std::lock_guard lock(entry->mu); + if (!entry->session) { + set_remote_runtime_error("Session is closed"); + return -1; + } + if (!symbol || !*symbol) { + set_remote_runtime_error("Empty symbol passed to __catalyst__remote__call_wrapper"); + return -1; + } + if (remote_verbose()) { + std::fprintf(stderr, "[remote] call_wrapper(addr=%s, sym=%s, in_size=%zu)\n", addr, symbol, + args_size); + } + char *buf = nullptr; + size_t n = 0; + int rc = catalyst::remote::call_wrapper_raw(entry->session, symbol, args_buf, args_size, &buf, + &n); + if (rc != 0) { + set_remote_runtime_error(catalyst::remote::last_error()); + return -1; + } + if (out_buf) { + *out_buf = buf; + } + else { + std::free(buf); // caller didn't want the bytes back + } + if (out_size) { + *out_size = n; + } + return 0; +} + +void __catalyst__remote__free_result(void *buf) { std::free(buf); } + +int __catalyst__remote__close() +{ + std::lock_guard mapLock(g_map_mu); + for (auto &[addr, entry] : remote_sessions) { + std::lock_guard lock(entry->mu); + if (entry->session) { + catalyst::remote::close(entry->session); + entry->session = nullptr; + entry->loaded_paths.clear(); + } + } + remote_sessions.clear(); + return 0; +} + +const char *__catalyst__remote__last_error() { return g_remote_runtime_error.c_str(); } + +void __catalyst__remote__launch(const char *addr, const char *entry_symbol, size_t num_inputs, + void *const *input_descs, const size_t *input_ranks, + const size_t *input_elem_sizes, size_t num_outputs, + void *const *output_descs, const size_t *output_ranks, + const size_t *output_elem_sizes) +{ + RemoteEntry *entry = find_or_create_entry(addr, /*create_if_missing=*/false); + if (!entry) { + RT_FAIL("Can't find opened session"); + } + + std::lock_guard lock(entry->mu); + if (remote_verbose()) { + std::fprintf(stderr, "[remote] launch(addr=%s, symbol=%s, n_in=%zu, n_out=%zu)\n", addr, + entry_symbol, num_inputs, num_outputs); + } + if (!entry->session) { + RT_FAIL("Session is closed"); + } + uint64_t entry_addr = catalyst::remote::lookup(entry->session, entry_symbol); + if (!entry_addr) { + RT_FAIL(catalyst::remote::last_error()); + } + if (catalyst::remote::invoke_kernel(entry->session, entry_addr, num_inputs, input_descs, + input_ranks, input_elem_sizes, num_outputs, output_descs, + output_ranks, output_elem_sizes) != 0) { + RT_FAIL(catalyst::remote::last_error()); + } +} + +} // extern "C" diff --git a/runtime/lib/remote/RemoteSession.cpp b/runtime/lib/remote/RemoteSession.cpp new file mode 100644 index 0000000000..5d0ce4e293 --- /dev/null +++ b/runtime/lib/remote/RemoteSession.cpp @@ -0,0 +1,747 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RemoteSession.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/EPCDynamicLibrarySearchGenerator.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" +#include "llvm/ExecutionEngine/Orc/MapperJITLinkMemoryManager.h" +#include "llvm/ExecutionEngine/Orc/MemoryAccess.h" +#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" +#include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h" +#include "llvm/ExecutionEngine/Orc/Shared/SimpleRemoteEPCUtils.h" +#include "llvm/ExecutionEngine/Orc/Shared/TargetProcessControlTypes.h" +#include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h" +#include "llvm/ExecutionEngine/Orc/SimpleRemoteMemoryMapper.h" +#include "llvm/ExecutionEngine/Orc/TaskDispatch.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/TargetSelect.h" + +#include +#include + +using namespace llvm; +using namespace llvm::orc; + +// Public CAPI from librt_capi (declared in runtime/include/RuntimeCAPI.h). +// Allocates a host buffer and registers it with CTX->getMemoryManager(). +extern "C" void *__catalyst__rt__alloc_managed(size_t size); + +namespace { + +// The connection behaviours below are mirrored from `llvm/tools/llvm-jitlink/llvm-jitlink.cpp`, +// and the Session class is mirrored from the official LLVM JIT tutorial: +// `https://llvm.org/docs/tutorial/BuildingAJIT1.html`. + +// Memref descriptor layout: allocated*, aligned*, offset, sizes[], strides[]. +// Offsets of the fixed-size prefix are independent of rank. +// The information can be obtained from `mlir/ExecutionEngine/CRunnerUtils.h`. +constexpr size_t kAllocatedOff = 0; +constexpr size_t kAlignedOff = sizeof(void *); +constexpr size_t kOffsetOff = sizeof(void *) * 2; +constexpr size_t kShapeOff = sizeof(void *) * 2 + sizeof(size_t); + +void initialize_targets() +{ + static const bool inited = []() { + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmPrinters(); + return true; + }(); + (void)inited; +} + +// For avoiding the error message being overwritten by subsequent errors in async jobs. +// We use thread_local to store the error message. +thread_local std::string g_last_error; +void set_error(const std::string &msg) { g_last_error = msg; } +void clear_error() { g_last_error.clear(); } + +void check(Error E, const Twine &what) +{ + if (E) { + throw std::runtime_error((what + ": " + toString(std::move(E))).str()); + } +} + +// unwrap LLVM Expected to C++ exception +template T unwrap(Expected v, const Twine &what) +{ + check(v.takeError(), what); + return std::move(*v); +} + +// TCP connect (mirrored from llvm-jitlink) +std::string OutOfProcessExecutorConnect; + +Error createTCPSocketError(Twine Details) +{ + return make_error("Failed to connect TCP socket '" + + Twine(OutOfProcessExecutorConnect) + "': " + Details, + inconvertibleErrorCode()); +} + +Expected connectTCPSocket(std::string Host, std::string PortStr) +{ + addrinfo *AI; + addrinfo Hints{}; + Hints.ai_family = AF_INET; + Hints.ai_socktype = SOCK_STREAM; + Hints.ai_flags = AI_NUMERICSERV; + + if (int EC = getaddrinfo(Host.c_str(), PortStr.c_str(), &Hints, &AI)) { + return createTCPSocketError("Address resolution failed (" + StringRef(gai_strerror(EC)) + + ")"); + } + + // Cycle through the returned addrinfo structures and connect to the first + // reachable endpoint. + int SockFD; + addrinfo *Server; + for (Server = AI; Server != nullptr; Server = Server->ai_next) { + // socket might fail, e.g. if the address family is not supported. Skip to + // the next addrinfo structure in such a case. + if ((SockFD = socket(AI->ai_family, AI->ai_socktype, AI->ai_protocol)) < 0) + continue; + + // If connect returns null, we exit the loop with a working socket. + if (connect(SockFD, Server->ai_addr, Server->ai_addrlen) == 0) + break; + + close(SockFD); + } + freeaddrinfo(AI); + + // If we reached the end of the loop without connecting to a valid endpoint, + // dump the last error that was logged in socket() or connect(). + if (Server == nullptr) { + return createTCPSocketError(std::strerror(errno)); + } + + return SockFD; +} + +// Slab-based JIT-link memory manager: reserve a 1 GB slab on the remote once +Expected> +createSimpleRemoteMemoryManager(SimpleRemoteEPC &SREPC) +{ + SimpleRemoteMemoryMapper::SymbolAddrs SAs; + if (auto Err = SREPC.getBootstrapSymbols( + {{SAs.Instance, rt::SimpleExecutorMemoryManagerInstanceName}, + {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName}, + {SAs.Initialize, rt::SimpleExecutorMemoryManagerInitializeWrapperName}, + {SAs.Deinitialize, rt::SimpleExecutorMemoryManagerDeinitializeWrapperName}, + {SAs.Release, rt::SimpleExecutorMemoryManagerReleaseWrapperName}})) { + return std::move(Err); + } + // 1 GB for object's sections (e.g .text, .rodata, ...) + // It will be released once the Session is destroyed. + size_t SlabSize = 1024 * 1024 * 1024; + return MapperJITLinkMemoryManager::CreateWithMapper(SlabSize, SREPC, + SAs); +} + +Expected> getFile(const Twine &filename) +{ + auto F = MemoryBuffer::getFile(filename); + if (F) { + return std::move(*F); + } + return createFileError(filename, F.getError()); +} + +} // namespace + +namespace catalyst::remote { + +struct RemoteSession { + std::unique_ptr ES; + + DataLayout DL; + + MangleAndInterner Mangle; + ObjectLinkingLayer ObjectLayer; + + JITDylib &MainJD; + + ExecutorAddr alloc_fn{0}; + ExecutorAddr free_fn{0}; + ExecutorAddr invoke_fn{0}; + ExecutorAddr store_asset_fn{0}; + + RemoteSession(std::unique_ptr es, DataLayout dl) + : ES(std::move(es)), DL(std::move(dl)), Mangle(*this->ES, this->DL), ObjectLayer(*this->ES), + MainJD(this->ES->createBareJITDylib("
")) + { + MainJD.addGenerator( + cantFail(EPCDynamicLibrarySearchGenerator::GetForTargetProcess(*this->ES))); + } + + ~RemoteSession() + { + if (auto Err = ES->endSession()) { + ES->reportError(std::move(Err)); + } + } + + static Expected> Create(StringRef remote_addr) + { + initialize_targets(); + + OutOfProcessExecutorConnect = remote_addr.str(); + auto [Host, PortStr] = remote_addr.split(':'); + if (Host.empty()) { + return createTCPSocketError("Host name for -" + OutOfProcessExecutorConnect + + " can not be empty"); + } + if (PortStr.empty()) { + return createTCPSocketError("Port number in -" + OutOfProcessExecutorConnect + + " can not be empty"); + } + + auto SockFD = connectTCPSocket(Host.str(), PortStr.str()); + if (!SockFD) { + return SockFD.takeError(); + } + + auto setup = SimpleRemoteEPC::Setup(); + setup.CreateMemoryManager = createSimpleRemoteMemoryManager; + auto EPC = SimpleRemoteEPC::Create( + std::make_unique(std::nullopt), std::move(setup), + *SockFD, *SockFD); + if (!EPC) { + return EPC.takeError(); + } + + JITTargetMachineBuilder JTMB((*EPC)->getTargetTriple()); + auto DL = JTMB.getDefaultDataLayoutForTarget(); + if (!DL) { + return DL.takeError(); + } + + auto ES = std::make_unique(std::move(*EPC)); + return std::make_unique(std::move(ES), std::move(*DL)); + } + + Error addObjectFile(std::unique_ptr Buf) + { + return ObjectLayer.add(MainJD, std::move(Buf)); + } + + ExecutorAddr lookupSym(StringRef Name) + { + auto Sym = unwrap(ES->lookup({&MainJD}, Mangle(Name.str())), "lookup(" + Name + ")"); + return Sym.getAddress(); + } + + ExecutorProcessControl &getEPC() { return ES->getExecutorProcessControl(); } +}; + +// --------------------------------------------------------------------------- +// Memref Marshalling Helpers +// --------------------------------------------------------------------------- + +namespace { + +ExecutorAddr remote_alloc(RemoteSession *s, size_t size) +{ + ExecutorAddr ret; + auto &epc = s->getEPC(); + std::string error_prefix = "alloc(" + std::to_string(size) + ")"; + check(epc.callSPSWrapper(s->alloc_fn, ret, + static_cast(size)), + error_prefix); + if (!ret) { + throw std::runtime_error(error_prefix + " out of memory"); + } + return ret; +} + +void remote_free(RemoteSession *s, ExecutorAddr addr) +{ + auto &epc = s->getEPC(); + if (auto err = epc.callSPSWrapper(s->free_fn, addr)) { + consumeError(std::move(err)); + } +} + +void remote_write(RemoteSession *s, ExecutorAddr addr, const void *data, size_t size) +{ + auto &epc = s->getEPC(); + tpctypes::BufferWrite w{addr, ArrayRef(static_cast(data), size)}; + check(epc.getMemoryAccess().writeBuffers({w}), "write"); +} + +void remote_read(RemoteSession *s, ExecutorAddr addr, void *data, size_t size) +{ + ExecutorAddrRange r(addr, addr + size); + auto out = unwrap(s->getEPC().getMemoryAccess().readBuffers({r}), "read"); + if (out.empty()) { + throw std::runtime_error("read: empty read"); + } + if (out[0].size() != size) { + throw std::runtime_error("read: size mismatch"); + } + std::memcpy(data, out[0].data(), size); +} + +void remote_invoke(RemoteSession *s, ExecutorAddr entry, ArrayRef arg_addrs) +{ + auto &epc = s->getEPC(); + check(epc.callSPSWrapper)>(s->invoke_fn, + entry, arg_addrs), + "remote_invoke"); +} + +// Memref descriptors are layout-equivalent to: +// { void *allocated; void *aligned; size_t offset; size_t sizes[rank]; size_t strides[rank]; } +size_t memref_desc_size(size_t rank) +{ + return sizeof(void *) // void *allocated + + sizeof(void *) // void *aligned + + sizeof(size_t) // size_t offset + + sizeof(size_t) * rank // size_t sizes[rank] + + sizeof(size_t) * rank // size_t strides[rank] + ; +} + +size_t memref_data_size(const char *desc, size_t rank, size_t elem_size) +{ + if (rank == 0) { + return elem_size; + } + const size_t *shape = reinterpret_cast(desc + kShapeOff); + return std::accumulate(shape, shape + rank, elem_size, std::multiplies()); +} + +class RemoteAllocator { + private: + RemoteSession *sess; + std::vector addrs; + + public: + explicit RemoteAllocator(RemoteSession *s) : sess(s) {} + ~RemoteAllocator() + { + for (ExecutorAddr a : addrs) { + remote_free(sess, a); + } + } + RemoteAllocator(const RemoteAllocator &) = delete; + RemoteAllocator &operator=(const RemoteAllocator &) = delete; + + ExecutorAddr alloc(size_t size) + { + ExecutorAddr addr = remote_alloc(sess, size); + addrs.push_back(addr); + return addr; + } +}; + +} // namespace + +// --------------------------------------------------------------------------- +// RemoteSession's Exported APIs +// --------------------------------------------------------------------------- + +/** + * @brief Open a session to the remote device. + * + * @param remote_addr the remote address + * @return RemoteSession * the session object + */ +RemoteSession *open(const char *remote_addr) +{ + clear_error(); + try { + auto s = unwrap(RemoteSession::Create(remote_addr), "open(" + Twine(remote_addr) + ")"); + check(s->getEPC().getBootstrapSymbols({{s->alloc_fn, "catalyst_remote_alloc"}, + {s->free_fn, "catalyst_remote_free"}, + {s->invoke_fn, "catalyst_remote_invoke"}, + {s->store_asset_fn, "catalyst_remote_store_asset"}}), + "getBootstrapSymbols"); + return s.release(); + } + catch (const std::exception &e) { + set_error(e.what()); + return nullptr; + } +} + +/** + * @brief Close the session to the remote device. + * The remote session's destructor will handle the cleanup of the session. + * + * @param s the session object + */ +void close(RemoteSession *s) { delete s; } + +/** + * @brief Load an object file (cross-compiled for the remote arch) into the remote JIT. + * + * @param s the session object + * @param path the path to the object file + * @return int 0 on success, non-zero on error + */ +int load_object_path(RemoteSession *s, const char *path) +{ + clear_error(); + try { + auto buf = unwrap(getFile(path), "getFile(" + Twine(path) + ")"); + check(s->addObjectFile(std::move(buf)), "addObjectFile"); + return 0; + } + catch (const std::exception &e) { + set_error(e.what()); + return -1; + } +} + +/** + * @brief Load an asset file into the remote JIT. + * + * @param s the session object + * @param path the path to the asset file + * @return int 0 on success, -1 on error + */ +int load_asset_path(RemoteSession *s, const char *path) +{ + clear_error(); + try { + auto buf = unwrap(getFile(path), "getFile(" + Twine(path) + ")"); + ArrayRef bytes(buf->getBufferStart(), buf->getBufferSize()); + + int32_t rc = 0; + check(s->getEPC().callSPSWrapper, shared::SPSString)>( + s->store_asset_fn, rc, bytes, std::string(path)), + "store_asset"); + if (rc != 0) { + throw std::runtime_error("Got non-zero status from store_asset(" + std::string(path) + + "): " + std::to_string(rc)); + } + return 0; + } + catch (const std::exception &e) { + set_error(e.what()); + return -1; + } +} + +/** + * @brief Generic raw ORC wrapper-function call. Looks `sym` up on the executor, invokes its wrapper + * with `(args_buf, args_size)`, and copies the resulting byte buffer into a `out_buf` with the size + * of `out_size`. The caller is responsible for `free()` the result buffer. + * + * @param s The session object. + * @param sym The symbol of the function to call. + * @param args_buf The buffer of the arguments. + * @param args_size The size of the arguments. + * @param out_buf The buffer of the result. + * @param out_size The size of the result. + * @return int 0 on success, -1 on error. + */ +int call_wrapper_raw(RemoteSession *s, const char *sym, const char *args_buf, size_t args_size, + char **out_buf, size_t *out_size) +{ + clear_error(); + *out_buf = nullptr; + *out_size = 0; + try { + ExecutorAddr fn = s->lookupSym(sym); + if (!fn) { + throw std::runtime_error(std::string("symbol not found: ") + sym); + } + auto result = s->getEPC().callWrapper(fn, ArrayRef(args_buf, args_size)); + size_t n = result.size(); + char *buf = nullptr; + if (n > 0) { + buf = static_cast(std::malloc(n)); + if (!buf) { + throw std::runtime_error("malloc failed for wrapper result"); + } + std::memcpy(buf, result.data(), n); + } + *out_buf = buf; + *out_size = n; + return 0; + } + catch (const std::exception &e) { + set_error(e.what()); + return -1; + } +} + +/** + * @brief Lookup the address of a symbol in the remote device. + * + * @param s the session object + * @param name the name of the symbol + * @return uint64_t the address of the symbol + */ +uint64_t lookup(RemoteSession *s, const char *name) +{ + clear_error(); + try { + return s->lookupSym(name).getValue(); + } + catch (const std::exception &e) { + set_error(e.what()); + return 0; + } +} + +/** + * @brief Run the kernel as a main function (take argv as arguments, argc is the length of argv). + * + * @param s the session object + * @param entry the entry function address + * @param argv the command line arguments + * @return int32_t the exit code + */ +int32_t run_as_main(RemoteSession *s, uint64_t entry_addr, int argc, const char *const *argv) +{ + clear_error(); + try { + std::vector args; + args.reserve(argc); + for (int i = 0; i < argc; ++i) { + args.emplace_back(argv[i]); + } + return unwrap(s->getEPC().runAsMain(ExecutorAddr(entry_addr), args), "run_as_main"); + } + catch (const std::exception &e) { + set_error(e.what()); + return -1; + } +} + +/** + * @brief Push one host memref to the remote: + * 1. allocates the data buffer on the remote + * 2. allocates the descriptor on the remote (which has a pointer to the data buffer) + * 3. returns the descriptor's remote addr. + * 4. If `copy_data` is true, the data will be copied to the remote. + * It's used for input memrefs like arguments. + * In the case of output memrefs, the data will be copied back to the host, + * so we don't need to copy the data to the remote. + * + * @param s the session object + * @param alloc the remote allocator + * @param host_desc the host memref descriptor + * @param rank the rank of the memref + * @param elem_size the element size of the memref + * @param copy_data whether to copy the data to the remote + * @return ExecutorAddr the remote address of the memref descriptor + */ +ExecutorAddr push_memref(RemoteSession *s, RemoteAllocator &alloc, void *host_desc, size_t rank, + size_t elem_size, bool copy_data) +{ + char *desc_host = static_cast(host_desc); + size_t desc_size = memref_desc_size(rank); + size_t data_size = memref_data_size(desc_host, rank, elem_size); + + // memref descriptor layout: + // ┌──────────────────────┐ ┌──────┐ + // │memref .allocated┼┬─►│buffer│ + // │descriptor .aligned ─┼┘ └──────┘ + // │ .offset │ + // │ .shape │ + // │ .strides │ + // └──────────────────────┘ + + std::vector desc(desc_size); + std::memcpy(desc.data(), desc_host, desc_size); + + ExecutorAddr data_remote = ExecutorAddr(0); + if (data_size > 0) { + data_remote = alloc.alloc(data_size); + if (copy_data) { + void *aligned_host = *reinterpret_cast(desc_host + kAlignedOff); + if (aligned_host) { + remote_write(s, data_remote, aligned_host, data_size); + } + } + } + std::memcpy(desc.data() + kAllocatedOff, &data_remote, sizeof(uintptr_t)); + std::memcpy(desc.data() + kAlignedOff, &data_remote, sizeof(uintptr_t)); + std::memset(desc.data() + kOffsetOff, 0, sizeof(int64_t)); + + ExecutorAddr desc_remote = alloc.alloc(desc_size); + remote_write(s, desc_remote, desc.data(), desc.size()); + return desc_remote; +} + +/** + * @brief Pull a remote memref descriptor + its data back into the host descriptor. + * + * @param s the session object + * @param remote_desc the remote address of the memref descriptor + * @param host_desc the host memref descriptor + * @param rank the rank of the memref + * @param elem_size the element size of the memref + */ +void pull_memref(RemoteSession *s, ExecutorAddr remote_desc, void *host_desc, size_t rank, + size_t elem_size) +{ + size_t desc_size = memref_desc_size(rank); + std::vector desc(desc_size); + remote_read(s, remote_desc, desc.data(), desc.size()); + + uintptr_t aligned_remote; + std::memcpy(&aligned_remote, desc.data() + kAlignedOff, sizeof(uintptr_t)); + + size_t data_size = memref_data_size(desc.data(), rank, elem_size); + size_t alloc_size = std::max(data_size, 1); + void *aligned_host = __catalyst__rt__alloc_managed(alloc_size); + if (data_size && aligned_remote) { + remote_read(s, ExecutorAddr(aligned_remote), aligned_host, data_size); + } + uintptr_t aligned_addr = reinterpret_cast(aligned_host); + std::memcpy(desc.data() + kAllocatedOff, &aligned_addr, sizeof(uintptr_t)); + std::memcpy(desc.data() + kAlignedOff, &aligned_addr, sizeof(uintptr_t)); + std::memcpy(host_desc, desc.data(), desc_size); +} + +/** + * @brief Invoke a remote kernel. + * + * @param s the session object + * @param entry_addr the address of the kernel entry function + * @param num_inputs the number of input memrefs + * @param input_descs the input memref descriptors + * @param input_ranks the ranks of the input memrefs + * @param input_elem_sizes the element sizes of the input memrefs + * @param num_outputs the number of output memrefs + * @param output_descs the output memref descriptors + * @param output_ranks the ranks of the output memrefs + * @param output_elem_sizes the element sizes of the output memrefs + * @return int 0 on success, non-zero on error + */ +int invoke_kernel(RemoteSession *s, uint64_t entry_addr, size_t num_inputs, + void *const *input_descs, const size_t *input_ranks, + const size_t *input_elem_sizes, size_t num_outputs, void *const *output_descs, + const size_t *output_ranks, const size_t *output_elem_sizes) +{ + clear_error(); + RemoteAllocator allocator(s); + try { + // The remote executor's catalyst_remote_invoke calls the entry as Catalyst's pyface ABI: + // `void(rv*, av*)`. + + // Layout (av): + // av (argument) is a struct whose Nth field is a pointer to the Nth input memref + // descriptor. So av = [N x uintptr_t] (array of remote descriptor addresses). + // + // av ──►┌───────────┐ ┌──────────────────────┐ ┌──────┐ + // │slot0 (ptr)┼────►│memref .allocated┼┬─►│buffer│ + // │slot1 │ │descriptor .aligned ─┼┘ └──────┘ + // │slot2 │ │ .offset │ + // │ ... │ │ .shape │ + // │ │ │ .strides │ + // └───────────┘ └──────────────────────┘ + + // Layout (rv): + // rv is a struct whose Nth field is the Nth output memref descriptor (not a pointer) + // + // rv ──►┌─────┬─────┬─────┬─────┬─────────┐ + // │desc0│desc1│desc2│desc3│ ... │ + // └─────┴─────┴─────┴─────┴─────────┘ + // Each slot maps to a output memref descriptor + + // Step 1. Push the input memref descriptors (with data) to the remote. + std::vector input_remote_descs(num_inputs); + for (size_t i = 0; i < num_inputs; ++i) { + input_remote_descs[i] = push_memref(s, allocator, input_descs[i], input_ranks[i], + input_elem_sizes[i], /*copy_data=*/true); + } + + // Step 2. Allocate a remote buffer holding the input memref descriptors' pointers. + ExecutorAddr av_remote = ExecutorAddr(0); + if (num_inputs > 0) { + av_remote = allocator.alloc(sizeof(uintptr_t) * num_inputs); + std::vector av_buf(num_inputs); + for (size_t i = 0; i < num_inputs; ++i) { + av_buf[i] = input_remote_descs[i].getValue(); + } + remote_write(s, av_remote, av_buf.data(), sizeof(uintptr_t) * num_inputs); + } + + // Step 3. Allocate a remote buffer for kernel to write the output memref descriptors. + std::vector output_offsets(num_outputs); + size_t rv_total = 0; + for (size_t i = 0; i < num_outputs; ++i) { + output_offsets[i] = rv_total; + rv_total += memref_desc_size(output_ranks[i]); + } + ExecutorAddr rv_remote = ExecutorAddr(0); + if (rv_total > 0) { + rv_remote = allocator.alloc(rv_total); + } + + // Step 4. Invoke the kernel remotely. + std::vector arg_addrs = {rv_remote, av_remote}; + remote_invoke(s, ExecutorAddr(entry_addr), arg_addrs); + + // Step 5. Pull each output descriptor back from rv buffer. + if (rv_total > 0) { + std::vector rv_buf(rv_total); + remote_read(s, rv_remote, rv_buf.data(), rv_total); + for (size_t i = 0; i < num_outputs; ++i) { + size_t desc_size = memref_desc_size(output_ranks[i]); + size_t elem_size = output_elem_sizes[i]; + char *desc = rv_buf.data() + output_offsets[i]; + + uintptr_t aligned_remote; + std::memcpy(&aligned_remote, desc + kAlignedOff, sizeof(uintptr_t)); + + size_t data_size = memref_data_size(desc, output_ranks[i], elem_size); + size_t alloc_size = std::max(data_size, 1); + void *aligned_host = __catalyst__rt__alloc_managed(alloc_size); + if (data_size && aligned_remote) { + remote_read(s, ExecutorAddr(aligned_remote), aligned_host, data_size); + } + uintptr_t aligned_addr = reinterpret_cast(aligned_host); + std::memcpy(desc + kAllocatedOff, &aligned_addr, sizeof(uintptr_t)); + std::memcpy(desc + kAlignedOff, &aligned_addr, sizeof(uintptr_t)); + std::memcpy(output_descs[i], desc, desc_size); + } + } + return 0; + } + catch (const std::exception &e) { + set_error(e.what()); + return -1; + } +} + +const char *last_error() { return g_last_error.c_str(); } + +} // namespace catalyst::remote diff --git a/runtime/lib/remote/RemoteSession.hpp b/runtime/lib/remote/RemoteSession.hpp new file mode 100644 index 0000000000..c8eb18f698 --- /dev/null +++ b/runtime/lib/remote/RemoteSession.hpp @@ -0,0 +1,55 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace catalyst::remote { + +// Opaque session handle. Created by open(), released by close(). +struct RemoteSession; + +// Open a TCP session to a `host:port` executor. Returns nullptr on error. +RemoteSession *open(const char *remote_addr); + +void close(RemoteSession *s); + +// Load an object file into the remote JIT. Returns 0 on success, -1 on error. +int load_object_path(RemoteSession *s, const char *path); + +// Load an asset file into the remote JIT. Returns 0 on success, -1 on error. +int load_asset_path(RemoteSession *s, const char *path); + +// Generic ORC wrapper-function call by symbol name. Returns 0 on success, -1 on error. +int call_wrapper_raw(RemoteSession *s, const char *sym, const char *args_buf, size_t args_size, + char **out_buf, size_t *out_size); + +// Look up a symbol address on the remote. Returns 0 on error. +uint64_t lookup(RemoteSession *s, const char *name); + +// Run a remote function as `main(argc, argv)`. +int32_t run_as_main(RemoteSession *s, uint64_t entry_addr, int argc, const char *const *argv); + +// Invoke a remote kernel. Returns 0 on success, -1 on error. +int invoke_kernel(RemoteSession *s, uint64_t entry_addr, size_t num_inputs, + void *const *input_descs, const size_t *input_ranks, + const size_t *input_elem_sizes, size_t num_outputs, void *const *output_descs, + const size_t *output_ranks, const size_t *output_elem_sizes); + +// Last error message. +const char *last_error(); + +} // namespace catalyst::remote From eb9d711e606363d5da8da3d07f5a88e59821ae67 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Wed, 27 May 2026 12:16:46 -0400 Subject: [PATCH 2/6] support memref inputs for remote_lib_call and close session with python exit --- .../BufferizableOpInterfaceImpl.cpp | 9 +- .../Catalyst/Transforms/catalyst_to_llvm.cpp | 107 ++++++++++++------ runtime/lib/remote/RemoteRuntime.cpp | 8 ++ 3 files changed, 84 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp index 2a3a6947fc..802e602373 100644 --- a/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Catalyst/Transforms/BufferizableOpInterfaceImpl.cpp @@ -179,10 +179,13 @@ struct CustomCallOpInterface // Add the initial number of arguments int32_t numArguments = static_cast(customCallOp.getNumOperands()); IntegerAttr numArgumentsAttr = rewriter.getI32IntegerAttr(numArguments); + auto newOp = CustomCallOp::create(rewriter, op->getLoc(), TypeRange{}, bufferArgs, + customCallOp.getCallTargetName(), numArgumentsAttr); - // Create an updated custom call operation - CustomCallOp::create(rewriter, op->getLoc(), TypeRange{}, bufferArgs, - customCallOp.getCallTargetName(), numArgumentsAttr); + // carry over any discardable attributes + for (NamedAttribute attr : op->getDiscardableAttrs()) { + newOp->setAttr(attr.getName(), attr.getValue()); + } size_t startIndex = bufferArgs.size() - customCallOp.getNumResults(); SmallVector bufferResults(bufferArgs.begin() + startIndex, bufferArgs.end()); bufferization::replaceOpWithBufferizedValues(rewriter, op, bufferResults); diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp index 626b5f4a91..c5dc90072c 100644 --- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp +++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp @@ -755,19 +755,29 @@ struct RemoteLibCallOpPattern : public OpConversionPattern { LLVM::LLVMFuncOp freeFn = catalyst::ensureFunctionDeclaration( rewriter, op, "__catalyst__remote__free_result", freeSig); - SmallVector offsets; - int64_t totalSize = 0; - for (Type ty : op.getOperandTypes()) { - int64_t n = primitiveByteSize(ty); - if (n < 0) { - return op->emitOpError("unsupported arg type for remote_lib_call: ") - << ty << " (supports int/float/index/complex only)"; + unsigned numInputs = op.getNumberOriginalArg().value_or(op.getNumOperands()); + if (numInputs > op.getNumOperands()) { + return op->emitOpError("number_original_arg exceeds operand count"); + } + + SmallVector inputOffsets; + int64_t totalInputBytes = 0; + for (unsigned i = 0; i < numInputs; ++i) { + int64_t argSize = memrefBufferBytes(op.getOperand(i).getType(), op); + if (argSize < 0) { + return failure(); + } + inputOffsets.push_back(totalInputBytes); + totalInputBytes += argSize; + } + for (unsigned i = numInputs; i < op.getNumOperands(); ++i) { + // only verify the result types + if (memrefBufferBytes(op.getOperand(i).getType(), op) < 0) { + return failure(); } - offsets.push_back(totalSize); - totalSize += n; } - Type bufTy = LLVM::LLVMArrayType::get(i8Ty, totalSize > 0 ? totalSize : 1); + Type bufTy = LLVM::LLVMArrayType::get(i8Ty, totalInputBytes > 0 ? totalInputBytes : 1); // Symbols Value addrPtr = getGlobalString(loc, rewriter, "remote_lib_addr_" + sym, @@ -776,14 +786,21 @@ struct RemoteLibCallOpPattern : public OpConversionPattern { // Alloca args buffer + store each arg. Value argsBuf = getStaticAlloca(loc, rewriter, bufTy, 1); - for (auto [llvmVal, off] : llvm::zip(adaptor.getOperands(), offsets)) { - Value slot = LLVM::GEPOp::create(rewriter, loc, ptrTy, bufTy, argsBuf, - ArrayRef{0, static_cast(off)}, - LLVM::GEPNoWrapFlags::inbounds); - LLVM::StoreOp::create(rewriter, loc, llvmVal, slot); + for (unsigned i = 0; i < numInputs; ++i) { + auto memrefTy = cast(op.getOperand(i).getType()); + int64_t numBytes = + memrefTy.getNumElements() * primitiveByteSize(memrefTy.getElementType()); + Value src = MemRefDescriptor(adaptor.getOperands()[i]).alignedPtr(rewriter, loc); + Value slot = LLVM::GEPOp::create( + rewriter, loc, ptrTy, bufTy, argsBuf, + ArrayRef{0, static_cast(inputOffsets[i])}, + LLVM::GEPNoWrapFlags::inbounds); + Value sizeVal = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numBytes)); + LLVM::MemcpyOp::create(rewriter, loc, slot, src, sizeVal, /*isVolatile=*/false); } Value argsSize = - LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(totalSize)); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(totalInputBytes)); // Alloca result buffer + size. Value outBufSlot = getStaticAlloca(loc, rewriter, ptrTy, 1); @@ -794,35 +811,51 @@ struct RemoteLibCallOpPattern : public OpConversionPattern { rewriter, loc, callFn, ValueRange{addrPtr, symPtr, argsBuf, argsSize, outBufSlot, outSizeSlot}); - // Decode return value (if any). - SmallVector returns; - Value outBuf; - if (!op.getResultTypes().empty()) { - if (op.getResultTypes().size() != 1) { - return op->emitOpError("remote_lib_call supports at most one result"); - } - Type retTy = op.getResultTypes().front(); - if (primitiveByteSize(retTy) < 0) { - return op->emitOpError("unsupported return type for remote_lib_call: ") << retTy; - } - Type retLLVMTy = getTypeConverter()->convertType(retTy); - outBuf = LLVM::LoadOp::create(rewriter, loc, ptrTy, outBufSlot); - Value rv = LLVM::LoadOp::create(rewriter, loc, retLLVMTy, outBuf); - returns.push_back(rv); - } - else { - outBuf = LLVM::LoadOp::create(rewriter, loc, ptrTy, outBufSlot); + Value outBuf = LLVM::LoadOp::create(rewriter, loc, ptrTy, outBufSlot); + int64_t outOffset = 0; + for (unsigned i = numInputs; i < op.getNumOperands(); ++i) { + auto memrefTy = cast(op.getOperand(i).getType()); + int64_t numBytes = + memrefTy.getNumElements() * primitiveByteSize(memrefTy.getElementType()); + Value destPtr = MemRefDescriptor(adaptor.getOperands()[i]).alignedPtr(rewriter, loc); + Value src = LLVM::GEPOp::create(rewriter, loc, ptrTy, i8Ty, outBuf, + ArrayRef{static_cast(outOffset)}, + LLVM::GEPNoWrapFlags::inbounds); + Value sizeVal = + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(numBytes)); + LLVM::MemcpyOp::create(rewriter, loc, destPtr, src, sizeVal, /*isVolatile=*/false); + outOffset += numBytes; } // Release the runtime-allocated result buffer. LLVM::CallOp::create(rewriter, loc, freeFn, ValueRange{outBuf}); - - rewriter.replaceOp(op, returns); + rewriter.eraseOp(op); return success(); } private: - // Supported scalar byte sizes. Returns -1 for unsupported types. + // Get the memref buffer bytes. + static int64_t memrefBufferBytes(Type ty, Operation *op) + { + auto memrefTy = dyn_cast(ty); + if (!memrefTy) { + op->emitOpError("remote_lib_call requires memref-typed operands; got ") << ty; + return -1; + } + if (!memrefTy.hasStaticShape()) { + op->emitOpError("remote_lib_call requires static-shape memref args; got ") << memrefTy; + return -1; + } + int64_t elemSz = primitiveByteSize(memrefTy.getElementType()); + if (elemSz < 0) { + op->emitOpError("unsupported memref element type for remote_lib_call: ") + << memrefTy.getElementType(); + return -1; + } + return memrefTy.getNumElements() * elemSz; + } + + // Supported element-type byte sizes. Returns -1 for unsupported types. static int64_t primitiveByteSize(Type ty) { if (auto i = dyn_cast(ty)) { diff --git a/runtime/lib/remote/RemoteRuntime.cpp b/runtime/lib/remote/RemoteRuntime.cpp index a97aec551a..a9c3cfedcf 100644 --- a/runtime/lib/remote/RemoteRuntime.cpp +++ b/runtime/lib/remote/RemoteRuntime.cpp @@ -32,6 +32,14 @@ struct RemoteEntry { std::set loaded_paths; // The paths of the binaries that are going to be loaded // into the remote session. std::mutex mu; // The mutex to protect the loaded paths. + + ~RemoteEntry() + { + if (session) { + catalyst::remote::close(session); + session = nullptr; + } + } }; std::mutex g_map_mu; From 9a842c16171135dc9c445a5ba847c30f7e0903b7 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 28 May 2026 09:05:41 -0400 Subject: [PATCH 3/6] rt_fail --- runtime/include/RemoteCAPI.h | 1 - runtime/lib/remote/CMakeLists.txt | 10 ++- runtime/lib/remote/RemoteRuntime.cpp | 91 ++++++++++++---------------- 3 files changed, 46 insertions(+), 56 deletions(-) diff --git a/runtime/include/RemoteCAPI.h b/runtime/include/RemoteCAPI.h index d6a778fe39..7cd69254f3 100644 --- a/runtime/include/RemoteCAPI.h +++ b/runtime/include/RemoteCAPI.h @@ -38,7 +38,6 @@ int __catalyst__remote__call_wrapper(const char *addr, const char *symbol, const void __catalyst__remote__free_result(void *buf); int __catalyst__remote__close(); -const char *__catalyst__remote__last_error(); #ifdef __cplusplus } // extern "C" diff --git a/runtime/lib/remote/CMakeLists.txt b/runtime/lib/remote/CMakeLists.txt index 37cd497abc..1030e54f26 100644 --- a/runtime/lib/remote/CMakeLists.txt +++ b/runtime/lib/remote/CMakeLists.txt @@ -10,8 +10,7 @@ llvm_map_components_to_libnames(_remote_llvm_libs ) # catalyst_remote_session - -add_library(catalyst_remote_session STATIC RemoteSession.cpp) +add_library(catalyst_remote_session SHARED RemoteSession.cpp) target_include_directories(catalyst_remote_session PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} @@ -22,6 +21,7 @@ target_include_directories(catalyst_remote_session SYSTEM PRIVATE ) target_link_libraries(catalyst_remote_session PRIVATE + rt_capi ${_remote_llvm_libs} pthread ${CMAKE_DL_LIBS} @@ -31,6 +31,12 @@ target_compile_options(catalyst_remote_session PRIVATE -fno-rtti) set_property(TARGET catalyst_remote_session PROPERTY POSITION_INDEPENDENT_CODE ON) +if(NOT APPLE) + set_property(TARGET catalyst_remote_session APPEND PROPERTY BUILD_RPATH $ORIGIN) +else() + set_property(TARGET catalyst_remote_session APPEND PROPERTY BUILD_RPATH @loader_path) +endif() + # rt_remote add_library(rt_remote SHARED RemoteRuntime.cpp) diff --git a/runtime/lib/remote/RemoteRuntime.cpp b/runtime/lib/remote/RemoteRuntime.cpp index a9c3cfedcf..55f662b75e 100644 --- a/runtime/lib/remote/RemoteRuntime.cpp +++ b/runtime/lib/remote/RemoteRuntime.cpp @@ -48,17 +48,6 @@ std::mutex g_map_mu; // sessions. std::map> remote_sessions; -thread_local std::string g_remote_runtime_error; -void set_remote_runtime_error(const char *msg) -{ - if (msg) { - g_remote_runtime_error = msg; - } - else { - g_remote_runtime_error = "(unknown)"; - } -} - // DEBUG logs bool remote_verbose() { @@ -94,48 +83,49 @@ extern "C" { int __catalyst__remote__open(const char *addr) { if (!addr || !*addr) { - set_remote_runtime_error("Empty address"); - return -1; + RT_FAIL("Empty address"); } RemoteEntry *entry = find_or_create_entry(addr, /*create_if_missing=*/true); - std::lock_guard lock(entry->mu); - if (entry->session) { - return 0; // idempotent per addr - } - if (remote_verbose()) { - std::fprintf(stderr, "[remote] open(addr=%s)\n", addr); + std::string err_msg; + { + std::lock_guard lock(entry->mu); + if (entry->session) { + return 0; // idempotent per addr + } + if (remote_verbose()) { + std::fprintf(stderr, "[remote] open(addr=%s)\n", addr); + } + entry->session = catalyst::remote::open(addr); + if (entry->session) { + if (remote_verbose()) { + std::fprintf(stderr, "[remote] open(%s) OK\n", addr); + } + return 0; + } + err_msg = "Could not connect to catalyst-executor at "; + err_msg += addr; + err_msg += ": "; + err_msg += catalyst::remote::last_error(); } - entry->session = catalyst::remote::open(addr); - if (!entry->session) { - std::string msg = "Could not connect to catalyst-executor at "; - msg += addr; - msg += ": "; - msg += catalyst::remote::last_error(); - set_remote_runtime_error(msg.c_str()); + { std::lock_guard mapLock(g_map_mu); remote_sessions.erase(addr); - return -1; - } - if (remote_verbose()) { - std::fprintf(stderr, "[remote] open(%s) OK\n", addr); } - return 0; + RT_FAIL(err_msg.c_str()); } int __catalyst__remote__send_binary(const char *addr, const char *path, uint32_t format) { RemoteEntry *entry = find_or_create_entry(addr, /*create_if_missing=*/false); if (!entry) { - set_remote_runtime_error("No session found, call __catalyst__remote__open first."); - return -1; + RT_FAIL("No session found, call __catalyst__remote__open first."); } std::lock_guard lock(entry->mu); if (!entry->session) { std::string msg = "__catalyst__remote__send_binary("; msg += addr; msg += "): session is closed."; - set_remote_runtime_error(msg.c_str()); - return -1; + RT_FAIL(msg.c_str()); } if (!path || !*path) { return 0; @@ -157,24 +147,25 @@ int __catalyst__remote__send_binary(const char *addr, const char *path, uint32_t case 1: rc = catalyst::remote::load_asset_path(entry->session, path); break; - default: + default: { std::string msg = "unknown binary format tag "; msg += std::to_string(format); - set_remote_runtime_error(msg.c_str()); - rc = -1; + entry->loaded_paths.erase(key); + RT_FAIL(msg.c_str()); + } } if (rc != 0) { - set_remote_runtime_error(catalyst::remote::last_error()); + std::string msg = catalyst::remote::last_error(); entry->loaded_paths.erase(key); - return -1; + RT_FAIL(msg.c_str()); } return 0; } /** * @brief Generic ORC wrapper-function call by symbol name. Returns 0 on success, -1 on error. - * + * * @param addr The address of the remote session. * @param symbol The symbol of the function to call. * @param args_buf The buffer of the arguments. @@ -194,17 +185,14 @@ int __catalyst__remote__call_wrapper(const char *addr, const char *symbol, const } RemoteEntry *entry = find_or_create_entry(addr, /*create_if_missing=*/false); if (!entry) { - set_remote_runtime_error("No session found, call __catalyst__remote__open first."); - return -1; + RT_FAIL("No session found, call __catalyst__remote__open first."); } std::lock_guard lock(entry->mu); if (!entry->session) { - set_remote_runtime_error("Session is closed"); - return -1; + RT_FAIL("Session is closed"); } if (!symbol || !*symbol) { - set_remote_runtime_error("Empty symbol passed to __catalyst__remote__call_wrapper"); - return -1; + RT_FAIL("Empty symbol passed to __catalyst__remote__call_wrapper"); } if (remote_verbose()) { std::fprintf(stderr, "[remote] call_wrapper(addr=%s, sym=%s, in_size=%zu)\n", addr, symbol, @@ -212,11 +200,10 @@ int __catalyst__remote__call_wrapper(const char *addr, const char *symbol, const } char *buf = nullptr; size_t n = 0; - int rc = catalyst::remote::call_wrapper_raw(entry->session, symbol, args_buf, args_size, &buf, - &n); + int rc = + catalyst::remote::call_wrapper_raw(entry->session, symbol, args_buf, args_size, &buf, &n); if (rc != 0) { - set_remote_runtime_error(catalyst::remote::last_error()); - return -1; + RT_FAIL(catalyst::remote::last_error()); } if (out_buf) { *out_buf = buf; @@ -247,8 +234,6 @@ int __catalyst__remote__close() return 0; } -const char *__catalyst__remote__last_error() { return g_remote_runtime_error.c_str(); } - void __catalyst__remote__launch(const char *addr, const char *entry_symbol, size_t num_inputs, void *const *input_descs, const size_t *input_ranks, const size_t *input_elem_sizes, size_t num_outputs, From 5d23da0cb9b44064bc99ae3f4db182d5011f8679 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 28 May 2026 09:09:52 -0400 Subject: [PATCH 4/6] no translator needed --- frontend/catalyst/utils/wrapper.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/frontend/catalyst/utils/wrapper.cpp b/frontend/catalyst/utils/wrapper.cpp index 1f0ae3ba9c..94788ef694 100644 --- a/frontend/catalyst/utils/wrapper.cpp +++ b/frontend/catalyst/utils/wrapper.cpp @@ -22,8 +22,6 @@ #include "numpy/ndarrayobject.h" -#include "Exception.hpp" - namespace nb = nanobind; struct memref_beginning_t { @@ -236,22 +234,6 @@ NB_MODULE(wrapper, m) { m.doc() = "wrapper module"; - nb::register_exception_translator([](const std::exception_ptr &p, void * /*payload*/) { - try { - std::rethrow_exception(p); - } - catch (const Catalyst::Runtime::RuntimeException &e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); - } - catch (const std::exception &e) { - PyErr_SetString(PyExc_RuntimeError, e.what()); - } - catch (...) { - PyErr_SetString(PyExc_RuntimeError, - "unknown C++ exception caught by wrapper translator"); - } - }); - // We have to annotate all the arguments to `wrap` to allow `result_desc` to be None // See https://nanobind.readthedocs.io/en/latest/functions.html#none-arguments m.def("wrap", &wrap, "A wrapper function.", nb::arg("func"), nb::arg("py_args"), From 4f7eca455ed55aa3a48a3a7fd4efd80d3f7d3c5f Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 1 Jun 2026 11:40:27 -0400 Subject: [PATCH 5/6] add remote address if remote lib call only --- .../Transforms/CrossCompileRemoteKernels.cpp | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp b/mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp index 9eab453a01..310623552f 100644 --- a/mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp +++ b/mlir/lib/Catalyst/Transforms/CrossCompileRemoteKernels.cpp @@ -95,21 +95,31 @@ struct CrossCompileRemoteKernelsPass return fn->hasAttr("qnode") && !fn->hasAttr("catalyst.remote_kernel"); }); - if (qnodes.empty()) { + bool hasPluginCalls = false; + host.walk([&](catalyst::CustomCallOp call) { + if (call.getCallTargetName() == "remote_lib_call") { + hasPluginCalls = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (qnodes.empty() && !hasPluginCalls) { return; } - if (workspace.empty()) { - host.emitError("Missing `workspace` option for remote kernel cross-compilation"); - return signalPassFailure(); - } + if (!qnodes.empty()) { + if (workspace.empty()) { + host.emitError("Missing `workspace` option for remote kernel cross-compilation"); + return signalPassFailure(); + } - // For cross-compilation, we need to initialize the LLVM target registry. - llvm::InitializeAllTargetInfos(); - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - llvm::InitializeAllAsmParsers(); - llvm::InitializeAllAsmPrinters(); + llvm::InitializeAllTargetInfos(); + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + llvm::InitializeAllAsmParsers(); + llvm::InitializeAllAsmPrinters(); + } injectRemoteOpenIntoSetup(host); @@ -254,8 +264,8 @@ struct CrossCompileRemoteKernelsPass std::unique_ptr targetMachine(llvmTarget->createTargetMachine( parsedTriple, cpu, features, opt, llvm::Reloc::Model::PIC_)); if (!targetMachine) { - llvm::errs() << "Could not create TargetMachine for triple '" << target - << "' cpu='" << cpu << "' features='" << features << "'\n"; + llvm::errs() << "Could not create TargetMachine for triple '" << target << "' cpu='" + << cpu << "' features='" << features << "'\n"; return ""; } From 822b5037f8cf4216a6ce0be6f038fc98278d4d15 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Mon, 1 Jun 2026 16:01:52 -0400 Subject: [PATCH 6/6] remove reudundant code --- frontend/catalyst/utils/CMakeLists.txt | 5 ----- frontend/catalyst/utils/wrapper.cpp | 1 - 2 files changed, 6 deletions(-) diff --git a/frontend/catalyst/utils/CMakeLists.txt b/frontend/catalyst/utils/CMakeLists.txt index 5566d4ebaf..51ec6d2350 100644 --- a/frontend/catalyst/utils/CMakeLists.txt +++ b/frontend/catalyst/utils/CMakeLists.txt @@ -38,11 +38,6 @@ nanobind_add_module(wrapper STABLE_ABI ${WRAPPER_SRC_FILES}) # Add the NumPy include directory to the library's include paths target_include_directories(wrapper PRIVATE ${Python_NumPy_INCLUDE_DIRS}) -# Catalyst runtime headers (for `Catalyst::Runtime::RuntimeException`). -target_include_directories(wrapper PRIVATE - ${CMAKE_CURRENT_LIST_DIR}/../../../runtime/include -) - # Use suffix ".so" rather than ".abi3.so" for library file using Stable ABI # This is necessary for compatibility with setuptools build extensions set_target_properties(wrapper PROPERTIES SUFFIX ".so") diff --git a/frontend/catalyst/utils/wrapper.cpp b/frontend/catalyst/utils/wrapper.cpp index 94788ef694..f9e5c29b8f 100644 --- a/frontend/catalyst/utils/wrapper.cpp +++ b/frontend/catalyst/utils/wrapper.cpp @@ -233,7 +233,6 @@ nb::list wrap(nb::object func, nb::tuple py_args, nb::object result_desc, nb::ob NB_MODULE(wrapper, m) { m.doc() = "wrapper module"; - // We have to annotate all the arguments to `wrap` to allow `result_desc` to be None // See https://nanobind.readthedocs.io/en/latest/functions.html#none-arguments m.def("wrap", &wrap, "A wrapper function.", nb::arg("func"), nb::arg("py_args"),