Skip to content

Commit 411e6d6

Browse files
authored
Merge pull request #87 from FluffyAIcode/AgentMemory/v04-pr-k2a1-cuda-bf16-dtype-fix-8e7f
PR-K2.A.1-hotfix: cast KakeyaLattice round-tripped K/V back to resident dtype (CUDA bf16 crash)
2 parents c5e8449 + a8adbd2 commit 411e6d6

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

inference_engine/v04/dlm_restored_verifier.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,17 @@ def _round_trip_resident_through_compressor(
201201
# resident positions get the round-tripped values.
202202
K_out = K.clone()
203203
V_out = V.clone()
204-
K_out.index_copy_(-2, pos_tensor, K_round_tripped)
205-
V_out.index_copy_(-2, pos_tensor, V_round_tripped)
204+
# The compressor round-trip may upcast to fp32 (KakeyaLattice's
205+
# quantize/dequantize math runs in fp32 for numerical fidelity),
206+
# whereas the resident K/V cache is the model's compute dtype
207+
# (bf16 on CUDA). index_copy_ requires matching dtype (and device),
208+
# so cast the round-tripped tensors back before writing them in.
209+
K_out.index_copy_(
210+
-2, pos_tensor, K_round_tripped.to(device=K_out.device, dtype=K_out.dtype),
211+
)
212+
V_out.index_copy_(
213+
-2, pos_tensor, V_round_tripped.to(device=V_out.device, dtype=V_out.dtype),
214+
)
206215
return K_out, V_out
207216

208217

0 commit comments

Comments
 (0)