Skip to content

Commit b049986

Browse files
committed
Add ModelParameters.enableSwaFull() (--swa-full) launch flag
Add a valueless boolean model launch flag mirroring the existing enableFlashAttn() / ModelFlag.FLASH_ATTN pattern. Default off.

--swa-full keeps the full-size sliding-window-attention KV cache so the SWA layers' KV becomes reusable across requests, restoring cross-request prompt-prefix cache reuse (pairs with setCacheReuse) at ~2x the SWA-layer KV RAM. Beneficial only for multi-request sessions sharing a prompt prefix, so it is opt-in.

- ModelFlag.SWA_FULL("--swa-full")
- ModelParameters.enableSwaFull()
- ModelFlagTest: enum->string mapping row + enum count 34->35
- ModelParametersExtendedTest: enableSwaFull + not-by-default tests
- CHANGELOG entry

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01SRXbL5RqW3B1XZRea3Rfc7
1 parent 6b7503d commit b049986

5 files changed

Lines changed: 31 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ from version 5.0.0 onward. Pre-fork releases (`1.x`–`4.2.0`) were authored by
1919
- Explicit `setMmprojAuto(boolean)` and `setMmprojOffload(boolean)` controls, including the upstream `--no-mmproj-auto` and `--no-mmproj-offload` flags.
2020
- Per-request KV controls: `InferenceParameters.withSlotId(int)` and `withCacheReuse(int)`.
2121
- Per-request DRY sampling to `InferenceParameters` (`dry_multiplier`/`dry_base`/`dry_allowed_length`/`dry_penalty_last_n`/`dry_sequence_breakers`).
22+
- `ModelParameters.enableSwaFull()` (`--swa-full`): keep full-size SWA KV cache to enable cross-request prompt-prefix reuse.
2223
- Typed cache observability through `Usage.getCachedTokens()`, `Usage.getProcessedPromptTokens()`, `SlotMetrics`, and `ServerMetrics.getSlotMetrics()`.
2324
- Authenticated JSON `GET /metrics` and `GET /slots` endpoints on the embedded server.
2425

src/main/java/net/ladenthin/llama/args/ModelFlag.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ public enum ModelFlag {
2222
/** Enable Flash Attention. */
2323
FLASH_ATTN("--flash-attn"),
2424

25+
/** Keep the full-size sliding-window-attention (SWA) KV cache, enabling cross-request
26+
* prompt-prefix reuse (pairs with --cache-reuse) at ~2x the SWA-layer KV RAM. Default off.
27+
* Env: LLAMA_ARG_SWA_FULL. */
28+
SWA_FULL("--swa-full"),
29+
2530
/** Disable internal libllama performance timings. */
2631
NO_PERF("--no-perf"),
2732

src/main/java/net/ladenthin/llama/parameters/ModelParameters.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,17 @@ public ModelParameters enableFlashAttn() {
255255
return setFlag(ModelFlag.FLASH_ATTN);
256256
}
257257

258+
/**
259+
* Use the full-size SWA KV cache so the sliding-window layers' KV is reusable across requests
260+
* (restores prompt-prefix cache reuse with {@link #setCacheReuse(int)}); costs ~2x SWA-layer
261+
* KV RAM. Off by default; only beneficial for multi-request sessions sharing a prompt prefix.
262+
*
263+
* @return this builder
264+
*/
265+
public ModelParameters enableSwaFull() {
266+
return setFlag(ModelFlag.SWA_FULL);
267+
}
268+
258269
/**
259270
* Disable internal libllama performance timings (default: false).
260271
*

src/test/java/net/ladenthin/llama/args/ModelFlagTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ public static Collection<Object[]> data() {
1919
return Arrays.asList(new Object[][] {
2020
{ModelFlag.NO_CONTEXT_SHIFT, "--no-context-shift"},
2121
{ModelFlag.FLASH_ATTN, "--flash-attn"},
22+
{ModelFlag.SWA_FULL, "--swa-full"},
2223
{ModelFlag.NO_PERF, "--no-perf"},
2324
{ModelFlag.ESCAPE, "--escape"},
2425
{ModelFlag.NO_ESCAPE, "--no-escape"},
@@ -66,7 +67,7 @@ public void testGetCliFlag(ModelFlag flag, String expectedCliFlag) {
6667

6768
@Test
6869
public void testEnumCount() {
69-
assertEquals(34, ModelFlag.values().length);
70+
assertEquals(35, ModelFlag.values().length);
7071
}
7172

7273
@ParameterizedTest(name = "{0} -> {1}")

src/test/java/net/ladenthin/llama/parameters/ModelParametersExtendedTest.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,18 @@ public void testEnableFlashAttn() {
641641
assertThat(p.parameters.get("--flash-attn"), is(nullValue()));
642642
}
643643

644+
@Test
645+
public void testEnableSwaFull() {
646+
ModelParameters p = new ModelParameters().enableSwaFull();
647+
assertThat(p.parameters, hasKey("--swa-full"));
648+
assertThat(p.parameters.get("--swa-full"), is(nullValue()));
649+
}
650+
651+
@Test
652+
public void testSwaFullNotEnabledByDefault() {
653+
assertThat(new ModelParameters().parameters, not(hasKey("--swa-full")));
654+
}
655+
644656
@Test
645657
public void testDisablePerf() {
646658
ModelParameters p = new ModelParameters().disablePerf();

0 commit comments

Comments
 (0)