
Commit cbc5cc6

michalharakal and claude committed
cleanup(gpu): delete GPU stubs and migrate native benchmark to DSL
Removes the placeholder GPU code paths in :llm-runtime:kllama and the native benchmark engine. There is no real GPU support in this repo — GpuAttentionBackend, GpuTensorBridge, and the createGpuBridge / createMetalContext / createMlxContext expect/actual chains were stubs that always fell back to CPU.

- Delete GpuAttentionBackend.kt and GpuTensorBridge.kt
- Strip createGpuTensorBridge / createGraphAccelerator from kllama BackendExpect.kt and the linux/macos/ios actuals (createGraphAccelerator was unused dead code)
- Drop createMetalContext / createMlxContext / createGpuBridge from llm-performance macosMain; only availableNativeBackends remains
- Rewrite NativeBenchmarkEngine: drop GpuNativeLlamaAdapter and the Metal/MLX scenario adapters; rename the scenario to native-cpu-throughput and migrate the CPU adapter to the DSL path (DecoderGgufWeightLoader + LlamaNetworkLoader.fromWeights + OptimizedLLMRuntime DIRECT), mirroring #127's JVM cleanup
- Drop the GpuAttentionBackend reference from the AttentionBackend kdoc

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6662c35 commit cbc5cc6
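
For orientation, here is the new DSL load path in one place, assembled from the NativeBenchmarkEngine.kt diff below. This is a sketch, not new API surface: every loader and runtime call mirrors a line visible in the diff, while the wrapper function, its parameters, and the fixed 16-step greedy generate call are illustrative.

import kotlinx.io.buffered
import kotlinx.io.files.Path
import kotlinx.io.files.SystemFileSystem
import sk.ainet.apps.llm.OptimizedLLMMode
import sk.ainet.apps.llm.OptimizedLLMRuntime
import sk.ainet.apps.llm.generate
import sk.ainet.context.DirectCpuExecutionContext
import sk.ainet.io.model.QuantPolicy
import sk.ainet.lang.types.FP32
import sk.ainet.models.llama.DecoderGgufWeightLoader
import sk.ainet.models.llama.LlamaNetworkLoader

suspend fun runDslPath(modelPathStr: String, promptTokens: IntArray) {
    // The stub GPU bridge is gone, so the CPU execution context is the only one created.
    val ctx = DirectCpuExecutionContext()
    val modelPath = Path(modelPathStr)
    // GGUF weights are dequantized to FP32 on load.
    val weights = DecoderGgufWeightLoader(
        sourceProvider = { SystemFileSystem.source(modelPath).buffered() },
        quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
    ).loadToMap<FP32, Float>(ctx)
    // Weight map -> network description -> runtime in DIRECT mode.
    val model = LlamaNetworkLoader.fromWeights(weights)
    val runtime = OptimizedLLMRuntime(
        model = model,
        ctx = ctx,
        mode = OptimizedLLMMode.DIRECT,
        dtype = FP32::class,
        bos = weights.metadata.bosTokenId,
    )
    // Greedy decoding (temperature 0.0f), the same call shape the benchmark loop uses.
    runtime.reset()
    runtime.generate(promptTokens, 16, 0.0f) { _ -> }
}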

9 files changed

Lines changed: 73 additions & 479 deletions


llm-inference/llama/src/commonMain/kotlin/sk/ainet/models/llama/AttentionBackend.kt

Lines changed: 1 addition & 2 deletions
@@ -8,8 +8,7 @@ import sk.ainet.lang.types.DType
  *
  * Encapsulates the divergent part of transformer layer execution:
  * RoPE encoding, KV cache management, and attention scoring.
- * Two implementations exist: CPU-based (CpuAttentionBackend) and
- * GPU-native (GpuAttentionBackend).
+ * The current production implementation is CPU-based (CpuAttentionBackend).
  *
  * Contract:
  * - Input: q [1, dim], k [1, kvDim], v [1, kvDim], layerIdx, position
Lines changed: 0 additions & 10 deletions
@@ -1,13 +1,3 @@
 package sk.ainet.performance.native
 
-import sk.ainet.apps.kllama.GpuTensorBridge
-import sk.ainet.context.ExecutionContext
-import sk.ainet.lang.types.DType
-
-internal actual fun createMetalContext(): ExecutionContext? = null
-
-internal actual fun createMlxContext(): ExecutionContext? = null
-
-internal actual fun <T : DType> createGpuBridge(ctx: ExecutionContext): GpuTensorBridge<T>? = null
-
 internal actual fun availableNativeBackends(): List<String> = listOf("CPU")
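
A note on why this file and the next one change together: a Kotlin multiplatform expect declaration in a shared source set must be matched by an actual in every target, so removing the macOS actuals above also requires removing their expect counterparts in NativeBenchmarkEngine.kt below. A minimal sketch of the one surviving pair (source-set placement as described in the commit message):

// Shared native source set (nativeMain): the declaration each target must satisfy.
internal expect fun availableNativeBackends(): List<String>

// macOS source set (macosMain): after this commit, CPU is the only backend reported.
internal actual fun availableNativeBackends(): List<String> = listOf("CPU")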

llm-performance/src/nativeMain/kotlin/sk/ainet/performance/native/NativeBenchmarkEngine.kt

Lines changed: 72 additions & 206 deletions
@@ -4,19 +4,15 @@ import kotlinx.io.buffered
 import kotlinx.io.files.Path
 import kotlinx.io.files.SystemFileSystem
 import kotlin.time.measureTime
-import sk.ainet.apps.kllama.CpuAttentionBackend
 import sk.ainet.apps.kllama.GGUFTokenizer
-import sk.ainet.apps.kllama.GpuAttentionBackend
-import sk.ainet.apps.kllama.GpuTensorBridge
-import sk.ainet.apps.kllama.LlamaIngestion
-import sk.ainet.apps.kllama.LlamaLoadConfig
+import sk.ainet.apps.llm.OptimizedLLMMode
+import sk.ainet.apps.llm.OptimizedLLMRuntime
+import sk.ainet.apps.llm.generate
 import sk.ainet.context.DirectCpuExecutionContext
-import sk.ainet.context.ExecutionContext
 import sk.ainet.io.model.QuantPolicy
-import sk.ainet.lang.types.DType
 import sk.ainet.lang.types.FP32
-import sk.ainet.models.llama.LlamaRuntime
-import sk.ainet.models.llama.LlamaRuntimeWeights
+import sk.ainet.models.llama.DecoderGgufWeightLoader
+import sk.ainet.models.llama.LlamaNetworkLoader
 import sk.ainet.performance.BenchmarkCaseResult
 import sk.ainet.performance.BenchmarkCaseStatus
 import sk.ainet.performance.BenchmarkMetric
@@ -49,15 +45,8 @@ private fun formatDouble1(value: Double): String {
     return "$intPart.$fracPart"
 }
 
-// ── Expect declarations for macOS-specific backend creation ──
-
-internal expect fun createMetalContext(): ExecutionContext?
-internal expect fun createMlxContext(): ExecutionContext?
-internal expect fun <T : DType> createGpuBridge(ctx: ExecutionContext): GpuTensorBridge<T>?
 internal expect fun availableNativeBackends(): List<String>
 
-// ── Data structures ──
-
 internal data class NamedPrompt(
     val label: String,
     val text: String,
@@ -68,188 +57,78 @@ internal data class PromptPlan(
     val promptTokens: IntArray,
 )
 
-// ── Adapter interface ──
-
-internal interface NativeLlamaAdapter {
-    val runtimeName: String
-
-    suspend fun runAllCases(
-        promptPlans: List<PromptPlan>,
-        stepCounts: List<Int>,
-        warmupRuns: Int,
-        measuredRuns: Int,
-    ): List<BenchmarkCaseResult>
-}
-
-// ── CPU adapter ──
-
-internal class CpuNativeLlamaAdapter(
+internal class CpuNativeDslAdapter(
     private val modelPathStr: String,
-) : NativeLlamaAdapter {
-    override val runtimeName: String = "CPU"
+) {
+    val runtimeName: String = "CPU"
 
-    override suspend fun runAllCases(
+    suspend fun runAllCases(
         promptPlans: List<PromptPlan>,
         stepCounts: List<Int>,
         warmupRuns: Int,
         measuredRuns: Int,
    ): List<BenchmarkCaseResult> {
         val ctx = DirectCpuExecutionContext()
+        val modelPath = Path(modelPathStr)
         log(" $runtimeName | loading model...")
-        val weights = loadWeights<FP32>(ctx, FP32::class, modelPathStr)
-        val backend = CpuAttentionBackend<FP32>(ctx, weights, FP32::class)
-        @Suppress("DEPRECATION")
-        val runtime = LlamaRuntime<FP32>(ctx, weights, backend, FP32::class)
-        log(" $runtimeName | model loaded")
-
-        return benchmarkCases(runtimeName, runtime, promptPlans, stepCounts, warmupRuns, measuredRuns)
-    }
-}
-
-// ── GPU adapter (Metal or MLX) ──
-
-internal class GpuNativeLlamaAdapter(
-    private val modelPathStr: String,
-    override val runtimeName: String,
-    private val contextFactory: () -> ExecutionContext?,
-) : NativeLlamaAdapter {
-
-    override suspend fun runAllCases(
-        promptPlans: List<PromptPlan>,
-        stepCounts: List<Int>,
-        warmupRuns: Int,
-        measuredRuns: Int,
-    ): List<BenchmarkCaseResult> {
-        val ctx = try {
-            contextFactory()
-        } catch (e: Exception) {
-            log(" $runtimeName | failed to create context: ${e.message}")
-            null
-        }
-
-        if (ctx == null) {
-            log(" $runtimeName | backend unavailable — skipping")
-            return skipAll(promptPlans, stepCounts)
-        }
-
-        log(" $runtimeName | loading model...")
-        val weights = try {
-            loadWeights<FP32>(ctx, FP32::class, modelPathStr)
-        } catch (e: Exception) {
-            log(" $runtimeName | model load failed: ${e.message}")
-            return skipAll(promptPlans, stepCounts, "Model load failed: ${e.message}")
-        }
-
-        val bridge = createGpuBridge<FP32>(ctx)
-        val backend = if (bridge != null) {
-            log(" $runtimeName | using GPU attention backend")
-            GpuAttentionBackend<FP32>(ctx, bridge, weights, FP32::class)
-        } else {
-            log(" $runtimeName | GPU bridge unavailable, falling back to CPU attention")
-            CpuAttentionBackend<FP32>(ctx, weights, FP32::class)
-        }
-
-        @Suppress("DEPRECATION")
-        val runtime = LlamaRuntime<FP32>(ctx, weights, backend, FP32::class)
-        log(" $runtimeName | model loaded")
-
-        return benchmarkCases(runtimeName, runtime, promptPlans, stepCounts, warmupRuns, measuredRuns)
-    }
-
-    private fun skipAll(
-        promptPlans: List<PromptPlan>,
-        stepCounts: List<Int>,
-        reason: String = "$runtimeName backend unavailable.",
-    ): List<BenchmarkCaseResult> = stepCounts.flatMap { steps ->
-        promptPlans.map { (prompt, promptTokens) ->
-            BenchmarkCaseResult(
-                caseId = "$runtimeName:${prompt.label}:$steps",
-                status = BenchmarkCaseStatus.SKIPPED,
-                runtime = runtimeName,
-                promptLabel = prompt.label,
-                promptTokenCount = promptTokens.size,
-                steps = steps,
-                metrics = emptyList(),
-                notes = listOf(reason),
-            )
-        }
-    }
-}
-
-// ── Shared helpers ──
-
-internal suspend fun <T : DType> loadWeights(
-    ctx: ExecutionContext,
-    dtype: kotlin.reflect.KClass<T>,
-    modelPathStr: String,
-): LlamaRuntimeWeights<T> {
-    val modelPath = Path(modelPathStr)
-    val ingestion = LlamaIngestion<T>(
-        ctx = ctx,
-        dtype = dtype,
-        config = LlamaLoadConfig(
+        val weights = DecoderGgufWeightLoader(
+            sourceProvider = { SystemFileSystem.source(modelPath).buffered() },
             quantPolicy = QuantPolicy.DEQUANTIZE_TO_FP32,
-            allowQuantized = false,
-        ),
-    )
-    return ingestion.load {
-        SystemFileSystem.source(modelPath).buffered()
-    }
-}
+        ).loadToMap<FP32, Float>(ctx)
+        val model = LlamaNetworkLoader.fromWeights(weights)
+        val runtime = OptimizedLLMRuntime(
+            model = model,
+            ctx = ctx,
+            mode = OptimizedLLMMode.DIRECT,
+            dtype = FP32::class,
+            bos = weights.metadata.bosTokenId,
+        )
+        log(" $runtimeName | model loaded")
 
-internal fun benchmarkCases(
-    runtimeName: String,
-    runtime: LlamaRuntime<FP32>,
-    promptPlans: List<PromptPlan>,
-    stepCounts: List<Int>,
-    warmupRuns: Int,
-    measuredRuns: Int,
-): List<BenchmarkCaseResult> {
-    val results = mutableListOf<BenchmarkCaseResult>()
-    for (steps in stepCounts) {
-        for ((prompt, promptTokens) in promptPlans) {
-            log(" $runtimeName | prompt=${prompt.label} steps=$steps | warming up ($warmupRuns runs)...")
-            repeat(warmupRuns) { i ->
-                runtime.reset()
-                runtime.generate(promptTokens, steps, 0.0f) { _ -> }
-                log(" warmup ${i + 1}/$warmupRuns done")
-            }
-            log(" $runtimeName | prompt=${prompt.label} steps=$steps | measuring ($measuredRuns runs)...")
-            val measurements = (1..measuredRuns).map { i ->
-                val ms = measureTime {
+        val results = mutableListOf<BenchmarkCaseResult>()
+        for (steps in stepCounts) {
+            for ((prompt, promptTokens) in promptPlans) {
+                log(" $runtimeName | prompt=${prompt.label} steps=$steps | warming up ($warmupRuns runs)...")
+                repeat(warmupRuns) { i ->
                     runtime.reset()
                     runtime.generate(promptTokens, steps, 0.0f) { _ -> }
-                }.inWholeMilliseconds
-                log(" measured $i/$measuredRuns: ${ms}ms")
-                ms
-            }.sorted()
-
-            val medianMillis = measurements[measuredRuns / 2].coerceAtLeast(1)
-            val throughput = steps.toDouble() / medianMillis * 1000.0
-            log(" $runtimeName | prompt=${prompt.label} steps=$steps | median=${medianMillis}ms throughput=${formatDouble2(throughput)} tok/s")
-
-            results += BenchmarkCaseResult(
-                caseId = "$runtimeName:${prompt.label}:$steps",
-                status = BenchmarkCaseStatus.SUCCESS,
-                runtime = runtimeName,
-                promptLabel = prompt.label,
-                promptTokenCount = promptTokens.size,
-                steps = steps,
-                metrics = listOf(
-                    BenchmarkMetric("throughput", throughput, "tok/s"),
-                    BenchmarkMetric("median_duration", medianMillis.toDouble(), "ms"),
-                ),
-            )
+                    log(" warmup ${i + 1}/$warmupRuns done")
+                }
+                log(" $runtimeName | prompt=${prompt.label} steps=$steps | measuring ($measuredRuns runs)...")
+                val measurements = (1..measuredRuns).map { i ->
+                    val ms = measureTime {
+                        runtime.reset()
+                        runtime.generate(promptTokens, steps, 0.0f) { _ -> }
+                    }.inWholeMilliseconds
+                    log(" measured $i/$measuredRuns: ${ms}ms")
+                    ms
+                }.sorted()
+
+                val medianMillis = measurements[measuredRuns / 2].coerceAtLeast(1)
+                val throughput = steps.toDouble() / medianMillis * 1000.0
+                log(" $runtimeName | prompt=${prompt.label} steps=$steps | median=${medianMillis}ms throughput=${formatDouble2(throughput)} tok/s")
+
+                results += BenchmarkCaseResult(
+                    caseId = "$runtimeName:${prompt.label}:$steps",
+                    status = BenchmarkCaseStatus.SUCCESS,
+                    runtime = runtimeName,
+                    promptLabel = prompt.label,
+                    promptTokenCount = promptTokens.size,
+                    steps = steps,
+                    metrics = listOf(
+                        BenchmarkMetric("throughput", throughput, "tok/s"),
+                        BenchmarkMetric("median_duration", medianMillis.toDouble(), "ms"),
+                    ),
+                )
+            }
         }
+        return results
     }
-    return results
 }
 
-// ── Scenario ──
-
-internal class NativeBackendThroughputScenario : BenchmarkScenario {
-    override val id: String = "native-backend-throughput"
-    override val description: String = "Compare CPU vs Metal vs MLX backend throughput on native macOS."
+internal class NativeCpuThroughputScenario : BenchmarkScenario {
+    override val id: String = "native-cpu-throughput"
+    override val description: String = "DSL CPU throughput on native (macOS)."
 
     private val prompts: List<NamedPrompt> = listOf(
         NamedPrompt("short", "Hello"),
@@ -275,29 +154,20 @@ internal class NativeBackendThroughputScenario : BenchmarkScenario {
         }
         log("Prompts tokenized: ${promptPlans.joinToString { "${it.prompt.label}(${it.promptTokens.size} tokens)" }}")
 
-        val adapters: List<NativeLlamaAdapter> = buildList {
-            add(CpuNativeLlamaAdapter(modelPathStr))
-            add(GpuNativeLlamaAdapter(modelPathStr, "Metal", ::createMetalContext))
-            add(GpuNativeLlamaAdapter(modelPathStr, "MLX", ::createMlxContext))
-        }
-
-        val results = mutableListOf<BenchmarkCaseResult>()
-        for ((index, adapter) in adapters.withIndex()) {
-            log("=== Backend ${index + 1}/${adapters.size}: ${adapter.runtimeName} ===")
-            val adapterResults = adapter.runAllCases(
-                promptPlans = promptPlans,
-                stepCounts = request.steps,
-                warmupRuns = request.warmupRuns,
-                measuredRuns = request.measuredRuns,
-            )
-            results += adapterResults
-            val successCount = adapterResults.count { it.status == BenchmarkCaseStatus.SUCCESS }
-            log("${adapter.runtimeName} finished: $successCount/${adapterResults.size} cases succeeded")
-        }
+        val adapter = CpuNativeDslAdapter(modelPathStr)
+        log("=== Backend: ${adapter.runtimeName} ===")
+        val results = adapter.runAllCases(
+            promptPlans = promptPlans,
+            stepCounts = request.steps,
+            warmupRuns = request.warmupRuns,
+            measuredRuns = request.measuredRuns,
+        )
+        val successCount = results.count { it.status == BenchmarkCaseStatus.SUCCESS }
+        log("${adapter.runtimeName} finished: $successCount/${results.size} cases succeeded")
 
         val finishedAt = epochMillis()
         val elapsedSec = (finishedAt - startedAt) / 1000.0
-        log("All backends complete. Total elapsed: ${formatDouble1(elapsedSec)}s")
+        log("Backend complete. Total elapsed: ${formatDouble1(elapsedSec)}s")
 
         return BenchmarkRunResult(
             scenarioId = id,
@@ -312,10 +182,8 @@ internal class NativeBackendThroughputScenario : BenchmarkScenario {
     }
 }
 
-// ── Orchestrator ──
-
 class NativeBenchmarkOrchestrator : BenchmarkRunner<BenchmarkRunRequest, BenchmarkRunResult> {
-    private val scenario = NativeBackendThroughputScenario()
+    private val scenario = NativeCpuThroughputScenario()
 
     override suspend fun run(config: BenchmarkRunRequest): BenchmarkRunResult {
        return scenario.execute(config)
@@ -326,8 +194,6 @@ class NativeBenchmarkOrchestrator : BenchmarkRunner<BenchmarkRunRequest, BenchmarkRunResult> {
     )
 }
 
-// ── Console reporter (matches JVM format) ──
-
 object NativeConsoleReporter {
     fun render(result: BenchmarkRunResult) {
         println("[BENCH] Scenario: ${result.scenarioId}")