Skip to content

Commit 0299a12

Browse files
authored
Slim down CompiledModule, move out execution logic (#4356)
This is the second follow up after #4322 from our discussion last week. This PR 'dumbs down' the `CompiledModule` type in the following sense - it does not make any assumptions anymore on what the compilation artifacts contain (e.g. before `JitArtifact` kept track of the various named entrypoints according to some naming convention) - it does not know what 'executing' a `CompiledModule` means The result is a pretty clean 'container' type that just knows how to store artifacts of different kinds. The only methods that still rely on some naming conventions are `getArgsCreator` and `getReturnOffset`, which exist purely for convenience (can be removed if you wish). This logic (which is dependent on conventions), now lives outside of the type. The naming conventions and entrypoint resolution was fully moved to construction (in `runtime/internal/compiler/include/cudaq_internal/compiler/CompiledModuleHelper.h`). The execution logic on the other hand was moved to the only placed it is currently used in (`runtime/cudaq/platform/qpu.cpp`). We might need to move this further as we unify kernel execution for C++ and Python, but wasn't sure where it would eventually land (and in what form). Signed-off-by: Luca Mondada <luca@mondada.net>
1 parent ccbf2a8 commit 0299a12

7 files changed

Lines changed: 205 additions & 238 deletions

File tree

python/runtime/cudaq/platform/py_alt_launch_kernel.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,8 +1143,7 @@ void cudaq::bindAltLaunchKernel(nanobind::module_ &mod,
11431143
.def_prop_ro(
11441144
"entry_point",
11451145
[](const cudaq::CompiledModule &ck) {
1146-
return reinterpret_cast<std::uintptr_t>(
1147-
ck.getJit().getEntryPoint());
1146+
return reinterpret_cast<std::uintptr_t>(ck.getJit()->getFn());
11481147
},
11491148
"The address of the JIT-compiled entry point.")
11501149
.def_prop_ro("is_fully_specialized",

runtime/common/CompiledModule.cpp

Lines changed: 40 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -7,47 +7,58 @@
77
******************************************************************************/
88

99
#include "CompiledModule.h"
10-
#include <cstring>
11-
#include <memory>
1210
#include <stdexcept>
1311

1412
cudaq::CompiledModule::CompiledModule(std::string kernelName)
1513
: name(std::move(kernelName)) {}
1614

17-
const cudaq::CompiledModule::JitArtifact &
18-
cudaq::CompiledModule::getJit() const {
19-
for (auto &[key, artifact] : artifacts)
20-
if (auto *jit = std::get_if<JitArtifact>(&artifact))
21-
return *jit;
22-
throw std::runtime_error("CompiledModule has no JIT artifact.");
15+
std::optional<cudaq::CompiledModule::JitArtifact>
16+
cudaq::CompiledModule::getJit(std::optional<std::string> jitName) const {
17+
auto name = jitName.value_or(this->name);
18+
auto it = artifacts.find(name);
19+
if (it == artifacts.end())
20+
return std::nullopt;
21+
const auto *jit = std::get_if<JitArtifact>(&it->second);
22+
return jit ? std::optional(*jit) : std::nullopt;
2323
}
2424

25-
const cudaq::CompiledModule::MlirArtifact &
26-
cudaq::CompiledModule::getMlir() const {
27-
for (auto &[key, artifact] : artifacts)
28-
if (auto *mlir = std::get_if<MlirArtifact>(&artifact))
29-
return *mlir;
30-
throw std::runtime_error("CompiledModule has no MLIR artifact.");
25+
std::optional<cudaq::CompiledModule::MlirArtifact>
26+
cudaq::CompiledModule::getMlir(std::optional<std::string> mlirName) const {
27+
auto name = mlirName.value_or(this->name + ".mlir");
28+
auto it = artifacts.find(name);
29+
if (it == artifacts.end())
30+
return std::nullopt;
31+
const auto *mlir = std::get_if<MlirArtifact>(&it->second);
32+
return mlir ? std::optional(*mlir) : std::nullopt;
3133
}
3234

33-
bool cudaq::CompiledModule::hasJit() const {
34-
for (auto &[key, artifact] : artifacts)
35-
if (std::holds_alternative<JitArtifact>(artifact))
36-
return true;
37-
return false;
35+
bool cudaq::CompiledModule::isFullySpecialized() const {
36+
return getArgsCreator() == nullptr;
3837
}
3938

40-
bool cudaq::CompiledModule::hasMlir() const {
41-
for (auto &[key, artifact] : artifacts)
42-
if (std::holds_alternative<MlirArtifact>(artifact))
43-
return true;
44-
return false;
39+
int64_t (*cudaq::CompiledModule::getArgsCreator() const)(const void *,
40+
void **) {
41+
auto jit = getJit(name + ".argsCreator");
42+
return jit ? reinterpret_cast<int64_t (*)(const void *, void **)>(jit->fn)
43+
: nullptr;
4544
}
4645

47-
bool cudaq::CompiledModule::isFullySpecialized() const {
48-
if (!hasJit())
49-
return true; // No JIT artifact → fully specialized.
50-
return getJit().argsCreator == nullptr;
46+
std::optional<std::int64_t> cudaq::CompiledModule::getReturnOffset() const {
47+
auto jit = getJit(name + ".returnOffset");
48+
if (!jit)
49+
return std::nullopt;
50+
auto fn = reinterpret_cast<std::int64_t (*)()>(jit->fn);
51+
return fn();
52+
}
53+
54+
const cudaq::Resources *cudaq::CompiledModule::getResources(
55+
std::optional<std::string> resourcesName) const {
56+
auto name = resourcesName.value_or(this->name + ".resources");
57+
auto it = artifacts.find(name);
58+
if (it == artifacts.end())
59+
return nullptr;
60+
const auto *resources = std::get_if<ResourcesArtifact>(&it->second);
61+
return resources ? &resources->getResources() : nullptr;
5162
}
5263

5364
void cudaq::CompiledModule::addArtifact(std::string name,
@@ -57,62 +68,8 @@ void cudaq::CompiledModule::addArtifact(std::string name,
5768
artifacts.emplace(std::move(name), std::move(artifact));
5869
}
5970

60-
cudaq::KernelThunkResultType
61-
cudaq::CompiledModule::execute(const std::vector<void *> &rawArgs) const {
62-
auto &jit = getJit();
63-
auto funcPtr = jit.entryPoint;
64-
if (!isFullySpecialized()) {
65-
// Pack args at runtime via argsCreator, then call the thunk.
66-
void *buff = nullptr;
67-
jit.argsCreator(static_cast<const void *>(rawArgs.data()), &buff);
68-
reinterpret_cast<KernelThunkResultType (*)(void *, bool)>(funcPtr)(
69-
buff, /*client_server=*/false);
70-
// If the kernel has a result, copy it from the packed buffer into
71-
// rawArgs.back() (where the caller expects to find it).
72-
if (resultInfo.hasResult()) {
73-
auto offset = jit.returnOffset();
74-
std::memcpy(rawArgs.back(), static_cast<char *>(buff) + offset,
75-
resultInfo.bufferSize);
76-
}
77-
std::free(buff);
78-
return {nullptr, 0};
79-
}
80-
if (resultInfo.hasResult()) {
81-
// Fully specialized with result: rawArgs.back() is the pre-allocated
82-
// result buffer; pass it directly to the thunk.
83-
void *buff = const_cast<void *>(rawArgs.back());
84-
return reinterpret_cast<KernelThunkResultType (*)(void *, bool)>(funcPtr)(
85-
buff, /*client_server=*/false);
86-
}
87-
// Fully specialized, no result.
88-
jit.entryPoint();
89-
return {nullptr, 0};
90-
}
91-
92-
cudaq::KernelThunkResultType cudaq::CompiledModule::execute() const {
93-
if (!isFullySpecialized())
94-
throw std::runtime_error(
95-
"Kernel has unspecialized parameters; call execute(rawArgs) instead.");
96-
if (!resultInfo.hasResult()) {
97-
getJit().entryPoint();
98-
return {nullptr, 0};
99-
}
100-
// Allocate a result buffer on-the-fly.
101-
auto buf = std::make_unique<char[]>(resultInfo.bufferSize);
102-
std::vector<void *> rawArgs = {buf.get()};
103-
execute(rawArgs);
104-
return {buf.release(), resultInfo.bufferSize};
105-
}
106-
107-
void (*cudaq::CompiledModule::JitArtifact::getEntryPoint() const)() {
108-
return entryPoint;
109-
}
71+
void (*cudaq::CompiledModule::JitArtifact::getFn() const)() { return fn; }
11072

11173
cudaq::JitEngine cudaq::CompiledModule::JitArtifact::getEngine() const {
11274
return engine;
11375
}
114-
115-
std::optional<cudaq::Resources>
116-
cudaq::CompiledModule::JitArtifact::getResourceCounts() const {
117-
return resourceCounts;
118-
}

runtime/common/CompiledModule.h

Lines changed: 56 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,15 @@ class ResultInfo {
8787
public:
8888
/// Whether this kernel has a result that must be marshaled.
8989
bool hasResult() const { return typeOpaquePtr != nullptr; }
90+
/// Get the size (in bytes) of the buffer needed to hold the result value.
91+
std::size_t getBufferSize() const { return bufferSize; }
9092
};
9193

9294
/// @brief A compiled MLIR module, ready for execution or code generation.
9395
///
9496
/// Contains any number of named compilation artifacts (we currently support
95-
/// JIT binaries and optimized MLIR modules) that result from the compilation
96-
/// of a Quake MLIR module.
97+
/// JIT binaries, optimized MLIR modules, and pre-computed resource metrics)
98+
/// that result from the compilation of a Quake MLIR module.
9799
///
98100
/// This type does not depend on MLIR/LLVM — it only keeps type-erased / opaque
99101
/// pointers. Build instances with
@@ -105,46 +107,18 @@ class CompiledModule {
105107
/// JIT-compiled artifact, ready for local execution.
106108
class JitArtifact {
107109
JitEngine engine;
108-
void (*entryPoint)() = nullptr;
109-
std::int64_t (*argsCreator)(const void *, void **) = nullptr;
110-
/// Offset (in bytes) of the result field within the argsCreator-packed
111-
/// buffer. Only valid when argsCreator is non-null and the kernel has a
112-
/// result. Use resultInfo.bufferSize to know how many bytes to copy.
113-
std::int64_t (*returnOffset)() = nullptr;
114-
std::optional<Resources> resourceCounts;
115-
116-
JitArtifact(JitEngine engine, void (*entryPoint)(),
117-
int64_t (*argsCreator)(const void *, void **),
118-
int64_t (*returnOffset)(),
119-
std::optional<Resources> resourceCounts)
120-
: engine(engine), entryPoint(entryPoint), argsCreator(argsCreator),
121-
returnOffset(returnOffset),
122-
resourceCounts(std::move(resourceCounts)) {}
110+
void (*fn)() = nullptr;
111+
112+
JitArtifact(JitEngine engine, void (*fn)())
113+
: engine(std::move(engine)), fn(fn) {}
123114

124115
friend class CompiledModule;
125116
friend class cudaq_internal::compiler::CompiledModuleHelper;
126117

127118
public:
128-
// TODO: remove the following two methods once the `CompiledModule` instance
129-
// is returned to Python.
130-
131-
/// @brief Get the entry point of the kernel as a function pointer.
132-
///
133-
/// Assumes that there is (exactly one) compiled JIT artifact.
134-
///
135-
/// The returned function pointer will expect different arguments depending
136-
/// on the kernel:
137-
/// - if the kernel returns a value and/or is not fully specialized, the
138-
/// entry point will expect a pointer to a buffer storing the packed
139-
/// arguments and result.
140-
/// - otherwise, the entry point will not expect any arguments.
141-
///
142-
/// Prefer using `CompiledModule::execute` instead of calling this function
143-
/// as it will handle the buffer and argument packing automatically.
144-
void (*getEntryPoint() const)();
119+
/// Get the raw function pointer stored in this artifact.
120+
void (*getFn() const)();
145121
JitEngine getEngine() const;
146-
147-
std::optional<Resources> getResourceCounts() const;
148122
};
149123

150124
/// Optimized MLIR module artifact, for deferred code generation or
@@ -165,8 +139,22 @@ class CompiledModule {
165139
friend class cudaq_internal::compiler::CompiledModuleHelper;
166140
};
167141

168-
/// A compiled artifact is either a JIT binary or an MLIR module.
169-
using CompiledArtifact = std::variant<JitArtifact, MlirArtifact>;
142+
/// Pre-computed resource metrics (gate counts, depth) from IR analysis.
143+
class ResourcesArtifact {
144+
Resources resources;
145+
146+
ResourcesArtifact(Resources resources) : resources(std::move(resources)) {}
147+
148+
friend class CompiledModule;
149+
friend class cudaq_internal::compiler::CompiledModuleHelper;
150+
151+
public:
152+
const Resources &getResources() const { return resources; }
153+
};
154+
155+
/// A compiled artifact is a JIT binary, an MLIR module, or resource metrics.
156+
using CompiledArtifact =
157+
std::variant<JitArtifact, MlirArtifact, ResourcesArtifact>;
170158

171159
// --- Compilation metadata ---
172160

@@ -178,47 +166,51 @@ class CompiledModule {
178166

179167
// --- Queries ---
180168

181-
/// Whether any artifact in the map is a JitArtifact.
182-
bool hasJit() const;
183-
184-
/// Whether any artifact in the map is an MlirArtifact.
185-
bool hasMlir() const;
169+
/// Get the JIT artifact with the given name.
170+
///
171+
/// If no name is provided, defaults to the kernel name.
172+
std::optional<JitArtifact>
173+
getJit(std::optional<std::string> jitName = std::nullopt) const;
186174

187-
/// Get the compiled JIT artifact. Returns the first one found.
175+
/// Get the MLIR artifact with the given name.
188176
///
189-
/// Throws if none exists.
190-
const JitArtifact &getJit() const;
177+
/// If no name is provided, defaults to `kernel_name + ".mlir"`.
178+
std::optional<MlirArtifact>
179+
getMlir(std::optional<std::string> mlirName = std::nullopt) const;
191180

192-
/// Get the optimized MLIR artifact. Returns the first one found.
181+
/// Get the pre-computed resource counts, or `nullptr` if it does not exist.
193182
///
194-
/// Throws if none exists.
195-
const MlirArtifact &getMlir() const;
183+
/// If no name is provided, defaults to `kernel_name + ".resources"`.
184+
const Resources *
185+
getResources(std::optional<std::string> resourcesName = std::nullopt) const;
196186

197187
/// Get all compiled artifacts.
198188
const std::map<std::string, CompiledArtifact> &getArtifacts() const {
199189
return artifacts;
200190
}
201191

202-
/// Whether the kernel is fully specialized (all arguments inlined). For JIT
203-
/// kernels this means `argsCreator` is null.
204-
/// Kernels without a JIT artifact are considered fully specialized.
192+
/// Whether the kernel is fully specialized (all arguments inlined).
193+
///
194+
/// Currently, kernels are considered fully specialized if and only if they do
195+
/// not have an `argsCreator` artifact.
205196
bool isFullySpecialized() const;
206197

207-
const std::string &getName() const { return name; }
208-
const ResultInfo &getResultInfo() const { return resultInfo; }
209-
const CompilationMetadata &getMetadata() const { return metadata; }
210-
211-
// --- Execution (local JIT path) ---
212-
213-
/// @brief Execute a fully specialized kernel (no external arguments needed).
198+
/// Get the argument-marshaling function, or `nullptr` if it does not exist.
214199
///
215-
/// Assumes that there is (exactly one) compiled JIT artifact.
216-
KernelThunkResultType execute() const;
200+
/// Assumes the artifact is named `kernelName + ".argsCreator"`.
201+
int64_t (*getArgsCreator() const)(const void *, void **);
217202

218-
/// @brief Execute the JIT-ed kernel with caller-provided arguments.
203+
/// Get the offset (in bytes) of the result field within the
204+
/// `argsCreator`-packed buffer, evaluating the stored JIT function.
205+
/// Returns `std::nullopt` if no `.returnOffset` artifact was emitted
206+
/// (e.g. the kernel has no result or is fully specialized).
219207
///
220-
/// Assumes that there is (exactly one) compiled JIT artifact.
221-
KernelThunkResultType execute(const std::vector<void *> &rawArgs) const;
208+
/// Assumes the artifact is named `kernelName + ".returnOffset"`.
209+
std::optional<std::int64_t> getReturnOffset() const;
210+
211+
const std::string &getName() const { return name; }
212+
const ResultInfo &getResultInfo() const { return resultInfo; }
213+
const CompilationMetadata &getMetadata() const { return metadata; }
222214

223215
private:
224216
friend class cudaq_internal::compiler::CompiledModuleHelper;

runtime/cudaq/platform/qpu.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,50 @@
88

99
#include "qpu.h"
1010
#include "mlir/IR/BuiltinOps.h"
11+
#include <cstring>
1112

1213
using namespace cudaq_internal::compiler;
1314

1415
LLVM_INSTANTIATE_REGISTRY(cudaq::ModuleLauncher::RegistryType)
1516

17+
/// Execute a JIT-compiled kernel with provided arguments.
18+
///
19+
/// Handles argument marshaling via `argsCreator` (if not fully specialized) and
20+
/// result buffer allocation.
21+
cudaq::KernelThunkResultType
22+
launchCompiledModule(const cudaq::CompiledModule &compiled,
23+
const std::vector<void *> &rawArgs) {
24+
auto funcPtr = compiled.getJit()->getFn();
25+
const auto &resultInfo = compiled.getResultInfo();
26+
if (!compiled.isFullySpecialized()) {
27+
// Pack args at runtime via argsCreator, then call the thunk.
28+
auto argsCreator = compiled.getArgsCreator();
29+
void *buff = nullptr;
30+
argsCreator(static_cast<const void *>(rawArgs.data()), &buff);
31+
reinterpret_cast<cudaq::KernelThunkResultType (*)(void *, bool)>(funcPtr)(
32+
buff, /*client_server=*/false);
33+
// If the kernel has a result, copy it from the packed buffer into
34+
// rawArgs.back() (where the caller expects to find it).
35+
if (resultInfo.hasResult()) {
36+
auto offset = compiled.getReturnOffset().value();
37+
std::memcpy(rawArgs.back(), static_cast<char *>(buff) + offset,
38+
resultInfo.getBufferSize());
39+
}
40+
std::free(buff);
41+
return {nullptr, 0};
42+
}
43+
if (resultInfo.hasResult()) {
44+
// Fully specialized with result: rawArgs.back() is the pre-allocated
45+
// result buffer; pass it directly to the thunk.
46+
void *buff = const_cast<void *>(rawArgs.back());
47+
return reinterpret_cast<cudaq::KernelThunkResultType (*)(void *, bool)>(
48+
funcPtr)(buff, /*client_server=*/false);
49+
}
50+
// Fully specialized, no result.
51+
funcPtr();
52+
return {nullptr, 0};
53+
}
54+
1655
cudaq::KernelThunkResultType
1756
cudaq::QPU::launchModule(const std::string &name, mlir::ModuleOp module,
1857
const std::vector<void *> &rawArgs) {
@@ -23,7 +62,7 @@ cudaq::QPU::launchModule(const std::string &name, mlir::ModuleOp module,
2362
"result of attempting to use `launchModule` outside Python.");
2463
ScopedTraceWithContext(cudaq::TIMING_LAUNCH, "QPU::launchModule", name);
2564
auto compiled = launcher->compileModule(name, module, rawArgs, true);
26-
return compiled.execute(rawArgs);
65+
return launchCompiledModule(compiled, rawArgs);
2766
}
2867

2968
cudaq::CompiledModule

0 commit comments

Comments
 (0)