
Commit 3f28cc2

Merge pull request #377 from SKaiNET-developers/feature/374-tool-calling
Feature/374 tool calling
2 parents 56edafd + e8efb99 commit 3f28cc2

27 files changed

Lines changed: 1910 additions & 23 deletions

ARCHITECTURE.md

Lines changed: 204 additions & 0 deletions
# SKaiNET Architecture: Where Agentic AI Fits

## The Core Question

After implementing tool calling support for KLlama, the question arises: **how does the agentic/tool-calling layer relate to the deep learning foundation?** Is it "real ML" or a higher-level orchestration concern?

**Answer**: Agentic AI is **not a deep learning primitive** — it's a **higher-level architectural pattern** that *consumes* the ML inference layer. The LLM (transformer forward pass, attention, embeddings) is pure deep learning. The agent loop that wraps it (chat formatting, tool parsing, execution, re-prompting) is application-level orchestration. Both are essential — one without the other is either a raw token generator or a tool executor with no intelligence.

---

## Diagram 1 — Full SKaiNET Layer Cake

All modules organized by abstraction level, with the agentic layer at the top:

```mermaid
graph TB
    subgraph APP["Application Layer"]
        CLI["skainet-kllama-cli<br/>--chat / --agent"]
    end

    subgraph AGENTIC["Agentic AI Layer (skainet-kllama-agent, orchestration, not ML)"]
        IR["InferenceRuntime&lt;T&gt;"]
        AL["AgentLoop&lt;T&gt;"]
        CT["ChatTemplate<br/>Llama3ChatTemplate / ChatMLTemplate"]
        TR["ToolRegistry"]
        TCP["ToolCallParser"]
        GEN["generateUntilStop()"]
    end

    subgraph INFERENCE["Inference Runtime Layer (skainet-kllama, ML forward pass)"]
        LR["LlamaRuntime&lt;T&gt;"]
        AB["AttentionBackend&lt;T&gt;<br/>CpuAttentionBackend / GpuAttentionBackend"]
        KV["KvCache<br/>HeapKvCache"]
        TOK["GGUFTokenizer"]
    end

    subgraph IO["Model I/O Layer"]
        GGUF["skainet-io-gguf"]
        ST["skainet-io-safetensors"]
        ONNX["skainet-io-onnx"]
    end

    subgraph COMPILE["Compilation Layer"]
        CC["skainet-compile-core<br/>Tape Recording"]
        CD["skainet-compile-dag<br/>Graph Optimization"]
        HLO["skainet-compile-hlo<br/>StableHLO Lowering"]
        CGEN["skainet-compile-c<br/>C99 Codegen"]
    end

    subgraph LANG["Tensor & NN Primitives Layer"]
        LC["skainet-lang-core<br/>Tensor&lt;T,V&gt;, DType, Shape"]
        NN["NN Layers<br/>Embedding, RMSNormalization, Linear"]
        OPS["Operators<br/>matmul, silu, softmax"]
    end

    subgraph BACKEND["Backend Execution Layer"]
        CPU["skainet-backend-cpu<br/>DirectCpuExecutionContext<br/>JDK 21 Vector API / SIMD"]
    end

    CLI --> AL
    AL --> CT
    AL --> TR
    AL --> TCP
    AL --> GEN
    AL --> IR
    GEN --> IR
    LR -.->|implements| IR
    LR --> AB
    AB --> KV
    LR --> TOK
    LR --> GGUF
    GGUF --> LC
    CC --> LC
    CD --> CC
    HLO --> CD
    NN --> LC
    OPS --> LC
    LC --> CPU
```

---

## Diagram 2 — Agent Loop Data Flow

The generate-parse-execute cycle that makes the system "agentic":

```mermaid
sequenceDiagram
    participant User
    participant AgentLoop
    participant ChatTemplate
    participant LlamaRuntime
    participant ToolCallParser
    participant ToolRegistry
    participant Tool

    User->>AgentLoop: "What is 42 * 17?"

    loop Up to maxToolRounds
        AgentLoop->>ChatTemplate: apply(messages + toolDefs)
        ChatTemplate-->>AgentLoop: formatted prompt string

        AgentLoop->>LlamaRuntime: generateUntilStop(tokens)

        Note over LlamaRuntime: ML BOUNDARY<br/>Embedding → Transformer Layers<br/>→ RoPE + Attention + KV Cache<br/>→ FFN (SiLU) → RMSNorm → Logits → Sample

        LlamaRuntime-->>AgentLoop: "I'll calculate that.<br/>{\"name\":\"calculator\",\"arguments\":{\"expression\":\"42*17\"}}"

        AgentLoop->>ToolCallParser: parse(response)
        ToolCallParser-->>AgentLoop: [ToolCall("calculator", {expression: "42*17"})]

        AgentLoop->>ToolRegistry: execute(toolCall)
        ToolRegistry->>Tool: execute({expression: "42*17"})
        Tool-->>ToolRegistry: "714"
        ToolRegistry-->>AgentLoop: "714"

        Note over AgentLoop: Append tool result as ChatMessage<br/>with role=TOOL, continue loop
    end

    AgentLoop-->>User: "42 * 17 = 714"
```
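The parse step in the cycle above can be sketched as a standalone round trip. This is a deliberately simplified, hypothetical rendition — the real `ToolCallParser` in `skainet-kllama-agent` is the authoritative version, and the type and signatures below are illustrative assumptions, not its actual API:

```kotlin
// Hypothetical, simplified sketch of the parse step in the diagram above.
// The real ToolCallParser in skainet-kllama-agent is the authoritative version.

data class ToolCall(val name: String, val arguments: Map<String, String>)

// Pull a {"name":...,"arguments":{...}} call out of raw model output,
// ignoring any surrounding free text ("I'll calculate that. ...").
fun parseToolCall(response: String): ToolCall? {
    val name = Regex("\"name\"\\s*:\\s*\"([^\"]+)\"")
        .find(response)?.groupValues?.get(1) ?: return null
    val args = Regex("\"([A-Za-z_]\\w*)\"\\s*:\\s*\"([^\"]+)\"")
        .findAll(response.substringAfter("\"arguments\""))
        .associate { it.groupValues[1] to it.groupValues[2] }
    return ToolCall(name, args)
}

fun main() {
    val raw = """I'll calculate that. {"name":"calculator","arguments":{"expression":"42*17"}}"""
    println(parseToolCall(raw))  // ToolCall(name=calculator, arguments={expression=42*17})
}
```

A regex-based extractor like this tolerates the free-form text models emit around the JSON payload, which is why the diagram routes the raw response through a parser rather than treating it as pure JSON.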

---

## Diagram 3 — ML vs Orchestration Boundary

What is deep learning and what is application architecture:

```mermaid
graph LR
    subgraph ORCHESTRATION["Higher-Level: Orchestration"]
        direction TB
        A1["AgentLoop&lt;T&gt;<br/><i>control flow</i>"]
        A2["ChatTemplate<br/><i>string formatting</i>"]
        A3["ToolCallParser<br/><i>regex + JSON parsing</i>"]
        A4["ToolRegistry<br/><i>dispatch table</i>"]
        A5["ChatMessage / ChatRole<br/><i>data structures</i>"]
    end

    subgraph ML["Deep Learning: Math"]
        direction TB
        M1["LlamaRuntime.forward()<br/><i>transformer decoder</i>"]
        M2["Embedding lookup"]
        M3["RoPE + Multi-Head Attention"]
        M4["SiLU-gated FFN"]
        M5["RMSNormalization"]
        M6["Softmax sampling"]
        M7["KvCache management"]
        M8["Tensor&lt;T,V&gt; operations<br/><i>matmul, add, silu</i>"]
        M9["SIMD kernels<br/><i>JDK 21 Vector API</i>"]
    end

    ORCHESTRATION -->|"calls"| ML
    ML -->|"returns tokens"| ORCHESTRATION

    style ORCHESTRATION fill:#ffe0e0,stroke:#cc0000
    style ML fill:#e0ffe0,stroke:#00aa00
```

---

## Key Design Insights

### The agent layer adds no trainable parameters

It's pure control flow. The "intelligence" comes entirely from the LLM weights loaded from GGUF files via `LlamaWeightLoader`. `AgentLoop` decides *when* to call the model, not *what* the model says. The orchestration layer is stateless in the ML sense — it holds conversation history (`List<ChatMessage>`) but no learned weights.
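That orchestration state can be sketched in a few lines. `ChatMessage` and `ChatRole` are the real type names from the agent module, but the exact fields shown here are assumptions for illustration:

```kotlin
// Sketch of the conversation state: plain data, no learned weights.
// ChatMessage/ChatRole mirror the types named above; their fields are assumed.
enum class ChatRole { SYSTEM, USER, ASSISTANT, TOOL }

data class ChatMessage(val role: ChatRole, val content: String)

fun main() {
    // The agent's entire "state": a growing history list.
    val history = mutableListOf(
        ChatMessage(ChatRole.SYSTEM, "You can call tools."),
        ChatMessage(ChatRole.USER, "What is 42 * 17?"),
    )
    history += ChatMessage(ChatRole.TOOL, "714")  // tool result appended with role=TOOL
    println(history.size)  // 3
}
```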

### Why it matters anyway

Without the agent loop, the model is a one-shot text completer — you feed it tokens, it predicts the next ones, done. With it, the model can reason over multiple steps, call external tools, and incorporate real-world data. The same `LlamaRuntime<T>` that powers `--chat` mode becomes an autonomous agent in `--agent` mode, simply by wrapping it in `AgentLoop<T>`.

### The clean boundary

`InferenceRuntime<T>.forward(tokenId: Int): Tensor<T, Float>` is the ML boundary. The agent module (`skainet-kllama-agent`) defines this interface, and concrete runtimes like `LlamaRuntimeInterface<T>` extend it. Everything below (tensors, attention, SIMD kernels in `skainet-backend-cpu`) is deep learning. Everything above (chat formatting in `ChatTemplate`, tool parsing in `ToolCallParser`, the agent loop in `AgentLoop`) is software engineering orchestration.

```
┌──────────────────────────────────┐
│ AgentLoop / ChatTemplate / CLI   │ ← orchestration (skainet-kllama-agent)
├──────────────────────────────────┤
│ InferenceRuntime<T>.forward()    │ ← THE BOUNDARY
├──────────────────────────────────┤
│ LlamaRuntimeInterface<T>         │ ← extends InferenceRuntime (skainet-kllama)
│ Attention / FFN / KvCache        │ ← deep learning
│ Tensor<T,V> / SIMD kernels       │
└──────────────────────────────────┘
```
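A minimal, runnable sketch of that boundary, with `FloatArray` standing in for `Tensor<T, Float>` so the example runs without the tensor stack — the real signature is the generic one quoted above:

```kotlin
// The ML boundary reduced to its essence: token in, next-token logits out.
// FloatArray stands in for Tensor<T, Float> so this sketch is self-contained.
interface SimpleInferenceRuntime {
    fun forward(tokenId: Int): FloatArray
}

// The orchestration side only ever consumes logits; it never sees attention,
// KV caches, or SIMD kernels.
fun greedyNext(runtime: SimpleInferenceRuntime, tokenId: Int): Int {
    val logits = runtime.forward(tokenId)
    return logits.indices.maxBy { logits[it] }
}

fun main() {
    // A fake runtime whose logits always prefer token 2 — enough to drive
    // the orchestration side without any deep learning underneath.
    val fake = object : SimpleInferenceRuntime {
        override fun forward(tokenId: Int) = floatArrayOf(0.1f, 0.2f, 5.0f)
    }
    println(greedyNext(fake, 0))  // 2
}
```

This is also why the boundary is easy to test: the agent layer can be exercised against a fake runtime like the one above, with no model weights loaded.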

### Both layers are in `commonMain`

The agent layer is multiplatform Kotlin, not JVM-specific. `AgentLoop`, `ChatTemplate`, `ToolRegistry`, `ToolCallParser`, and all supporting types live in `skainet-kllama-agent/src/commonMain/`. The same agent loop runs on JVM (with Vector API SIMD), Native, and WASM targets — the only platform-specific code is the backend execution layer (`skainet-backend-cpu`) and the CLI entry point (`skainet-kllama-cli`).

---

## Module Reference

| Layer | Module | Key Types |
|-------|--------|-----------|
| Application | `skainet-apps:skainet-kllama-cli` | `Main.kt` (`--chat`, `--agent`) |
| Agentic | `skainet-apps:skainet-kllama-agent` | `InferenceRuntime<T>`, `AgentLoop<T>`, `ChatTemplate`, `Llama3ChatTemplate`, `ChatMLTemplate`, `ToolRegistry`, `ToolCallParser`, `ToolCall`, `Tool`, `ToolDefinition`, `ChatMessage`, `ChatRole`, `GenerateResult`, `generateUntilStop()`, `sampleFromLogits()` |
| Inference | `skainet-apps:skainet-kllama` | `LlamaRuntime<T>`, `LlamaRuntimeInterface<T>` (extends `InferenceRuntime<T>`), `AttentionBackend<T>`, `CpuAttentionBackend<T>`, `GpuAttentionBackend<T>`, `KvCache`, `HeapKvCache`, `GGUFTokenizer` |
| Model I/O | `skainet-io:skainet-io-gguf`, `skainet-io:skainet-io-safetensors`, `skainet-io:skainet-io-onnx` | `LlamaWeightLoader`, `LlamaRuntimeWeights<T>` |
| Compilation | `skainet-compile:skainet-compile-core`, `skainet-compile-dag`, `skainet-compile-hlo`, `skainet-compile-c` | Tape recording, graph optimization, StableHLO lowering, C99 codegen |
| Tensor/NN | `skainet-lang:skainet-lang-core` | `Tensor<T,V>`, `Shape`, `DType`, `Embedding`, `Linear`, `RMSNormalization` |
| Backend | `skainet-backends:skainet-backend-cpu` | `DirectCpuExecutionContext`, `DefaultCpuOps` |
settings.gradle.kts

Lines changed: 1 addition & 0 deletions

```diff
@@ -75,6 +75,7 @@ include("skainet-apps:skainet-tensor-tools")
 include("skainet-apps:skainet-llm")
 include("skainet-apps:skainet-bert")
 include("skainet-apps:skainet-kllama")
+include("skainet-apps:skainet-kllama-agent")
 include("skainet-apps:skainet-kllama-cli")
 include("skainet-apps:skainet-kgemma")
 include("skainet-apps:skainet-kbert-cli")
```
Lines changed: 59 additions & 0 deletions

```kotlin
import org.jetbrains.kotlin.gradle.ExperimentalKotlinGradlePluginApi
import org.jetbrains.kotlin.gradle.ExperimentalWasmDsl
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

plugins {
    alias(libs.plugins.kotlinMultiplatform)
    alias(libs.plugins.androidLibrary)
    alias(libs.plugins.kotlinSerialization)
}

kotlin {
    jvmToolchain(21)

    androidTarget {
        publishLibraryVariants("release")
        @OptIn(ExperimentalKotlinGradlePluginApi::class)
        compilerOptions {
            jvmTarget.set(JvmTarget.JVM_11)
        }
    }

    linuxX64()
    linuxArm64()
    macosArm64()
    jvm()

    js {
        browser()
    }

    @OptIn(ExperimentalWasmDsl::class)
    wasmJs {
        browser()
    }

    sourceSets {
        commonMain.dependencies {
            implementation(project(":skainet-lang:skainet-lang-core"))
            implementation(libs.kotlinx.serialization.json)
        }

        commonTest.dependencies {
            implementation(libs.kotlin.test)
        }
    }
}

android {
    namespace = "sk.ainet.apps.kllama.agent"
    compileSdk = libs.versions.android.compileSdk.get().toInt()

    defaultConfig {
        minSdk = libs.versions.android.minSdk.get().toInt()
    }
    compileOptions {
        sourceCompatibility = JavaVersion.VERSION_11
        targetCompatibility = JavaVersion.VERSION_11
    }
}
```
Lines changed: 125 additions & 0 deletions

```kotlin
package sk.ainet.apps.kllama.agent

import kotlin.math.exp
import kotlin.random.Random
import sk.ainet.lang.tensor.Tensor
import sk.ainet.lang.tensor.data.FloatArrayTensorData
import sk.ainet.lang.types.DType

/**
 * Generate tokens until an EOS token is produced or [maxTokens] is reached.
 *
 * Unlike batch generation, this function:
 * - Stops when the model emits [eosTokenId]
 * - Does NOT prepend BOS automatically (the caller is responsible for encoding the
 *   full prompt including special tokens via the chat template)
 * - Returns a [GenerateResult] with all generated tokens and decoded text
 *
 * @param prompt Encoded prompt token IDs (should include BOS if needed).
 * @param maxTokens Maximum number of tokens to generate.
 * @param eosTokenId The EOS token ID to stop on.
 * @param temperature Sampling temperature (0 = greedy).
 * @param random Random generator for sampling.
 * @param onToken Optional callback invoked for each generated token.
 * @param decode Optional function to decode a token ID to a string.
 */
public fun <T : DType> InferenceRuntime<T>.generateUntilStop(
    prompt: IntArray,
    maxTokens: Int,
    eosTokenId: Int,
    temperature: Float = 0.8f,
    random: Random = Random.Default,
    onToken: ((Int) -> Unit)? = null,
    decode: ((Int) -> String)? = null
): GenerateResult {
    // Feed prompt tokens through the model
    var lastLogits: Tensor<T, Float>? = null
    for (tokenId in prompt) {
        lastLogits = forward(tokenId)
    }

    if (lastLogits == null) {
        return GenerateResult(emptyList(), "", false)
    }

    val generated = mutableListOf<Int>()
    val textBuilder = StringBuilder()
    var stoppedByEos = false

    var logits: Tensor<T, Float> = lastLogits
    for (step in 0 until maxTokens) {
        val nextToken = sampleFromLogits<T>(logits, temperature, random)

        if (nextToken == eosTokenId) {
            stoppedByEos = true
            break
        }

        generated.add(nextToken)
        onToken?.invoke(nextToken)
        decode?.let { textBuilder.append(it(nextToken)) }

        logits = forward(nextToken)
    }

    return GenerateResult(generated, textBuilder.toString(), stoppedByEos)
}

/**
 * Sample a token ID from a logits tensor.
 *
 * @param logits The logits tensor (1D, vocabSize).
 * @param temperature Sampling temperature. Values <= 1e-6 use greedy (argmax).
 * @param random Random generator.
 * @return The sampled token ID.
 */
public fun <T : DType> sampleFromLogits(
    logits: Tensor<T, Float>,
    temperature: Float,
    random: Random = Random.Default
): Int {
    val buf = logits.toFloatArray()

    // Greedy (argmax) for near-zero temperature
    if (temperature <= 1e-6f) {
        var best = 0
        var bestVal = buf[0]
        for (i in 1 until buf.size) {
            if (buf[i] > bestVal) {
                bestVal = buf[i]
                best = i
            }
        }
        return best
    }

    // Temperature-scaled softmax sampling
    var maxLogit = Float.NEGATIVE_INFINITY
    for (i in buf.indices) {
        val v = buf[i] / temperature
        buf[i] = v
        if (v > maxLogit) maxLogit = v
    }
    var sum = 0f
    for (i in buf.indices) {
        val e = exp((buf[i] - maxLogit).toDouble()).toFloat()
        buf[i] = e
        sum += e
    }
    val r = random.nextFloat() * sum
    var acc = 0f
    for (i in buf.indices) {
        acc += buf[i]
        if (acc >= r) return i
    }
    return buf.lastIndex
}

/**
 * Extract a FloatArray from a tensor, using the fast path if available.
 */
private fun <T : DType> Tensor<T, Float>.toFloatArray(): FloatArray {
    val data = this.data
    if (data is FloatArrayTensorData<*>) return data.buffer.copyOf()
    return data.copyToFloatArray()
}
```
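The sampling math in `sampleFromLogits` can be restated on a plain `FloatArray`, so it runs without the tensor stack. Same algorithm as above — greedy argmax below the temperature epsilon, otherwise temperature-scaled softmax sampling — just with the tensor-typed wrapper stripped away:

```kotlin
import kotlin.math.exp
import kotlin.random.Random

// Plain-array restatement of the sampling algorithm above, for illustration.
fun sample(logits: FloatArray, temperature: Float, random: Random = Random.Default): Int {
    // Greedy (argmax) for near-zero temperature.
    if (temperature <= 1e-6f) return logits.indices.maxBy { logits[it] }

    // Temperature-scaled softmax, with max subtraction for numerical stability.
    val scaled = FloatArray(logits.size) { logits[it] / temperature }
    val max = scaled.max()
    val probs = FloatArray(scaled.size) { exp((scaled[it] - max).toDouble()).toFloat() }

    // Sample proportionally to the (unnormalized) probabilities.
    val r = random.nextFloat() * probs.sum()
    var acc = 0f
    for (i in probs.indices) {
        acc += probs[i]
        if (acc >= r) return i
    }
    return probs.lastIndex
}

fun main() {
    val logits = floatArrayOf(0.1f, 2.5f, 0.3f)
    println(sample(logits, 0f))  // 1 — greedy picks the largest logit
}
```

Subtracting the max before `exp` leaves the softmax result unchanged but keeps the exponentials from overflowing when logits are large, which is why the original does the same.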
