Skip to content

Commit cd38dfb

Browse files
authored
[https://nvbugs/6226933][fix] canonicalize multimodal cache-key serialization to prevent hash collisions (#14800)
Signed-off-by: venkywonka <23023424+venkywonka@users.noreply.github.com>
1 parent 9653450 commit cd38dfb

3 files changed

Lines changed: 285 additions & 54 deletions

File tree

tensorrt_llm/inputs/multimodal.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
default_hasher = blake3
1919
_INT32_MAX = 2**31 - 1
2020

21+
# Versioned tag prefixed to every content hash so the canonical, self-describing
22+
# serialization scheme can evolve without silently reusing stale cache keys.
23+
_HASH_SCHEME_TAG = b"trtllm.mm.hash.v1"
24+
2125

2226
def strip_mm_data_for_generation(mm_data: Dict[str, Any]) -> None:
2327
"""Clear `mm_data` in place, retaining only `mrope_config.mrope_position_deltas`.
@@ -666,21 +670,10 @@ class MultimodalServerConfig():
666670

667671
def _update_hash(hasher, item: object) -> None:
668672
"""Hash the content of a multimodal item into the provided hasher."""
673+
hasher.update(_HASH_SCHEME_TAG)
669674
if isinstance(item, BaseModalityData):
670675
item.update_hash(hasher)
671676
return
672-
if isinstance(item, torch.Tensor):
673-
item = item.detach().cpu().contiguous()
674-
hasher.update(serialize_item(item))
675-
return
676-
if isinstance(item, list):
677-
for element in item:
678-
hasher.update(b"<frame>")
679-
if isinstance(element, torch.Tensor):
680-
element = element.detach().cpu().contiguous()
681-
hasher.update(serialize_item(element))
682-
return
683-
684677
hasher.update(serialize_item(item))
685678

686679

@@ -711,7 +704,6 @@ def apply_mm_hashes(
711704

712705
def _hash_item(item):
713706
"""Hash only the content of a multimodal item (no UUID)."""
714-
# TODO: possible hash collision w/ this simplified version (vllm/PR/17378)
715707
hasher = hash_lib()
716708
_update_hash(hasher, item)
717709
return hasher.hexdigest()

tensorrt_llm/inputs/multimodal_data.py

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
import struct
45
from dataclasses import dataclass
56
from typing import Any, Protocol
67

78
import numpy as np
89
import torch
910
from PIL import Image
1011

12+
# Video metadata fields that participate in the cache-key hash. These describe
13+
# how frames were sampled and therefore change the model-visible content.
14+
_VIDEO_HASH_METADATA_FIELDS = (
15+
"frames_indices",
16+
"fps",
17+
"duration",
18+
"total_num_frames",
19+
)
20+
1121

1222
class ContentHasher(Protocol):
1323
"""Hash object that accepts bytes."""
@@ -16,30 +26,93 @@ def update(self, data: bytes) -> None:
1626
"""Update the hash with raw bytes."""
1727

1828

29+
def _u8(value: int) -> bytes:
30+
"""Encode an unsigned 8-bit integer."""
31+
return value.to_bytes(1, "big", signed=False)
32+
33+
34+
def _u32(value: int) -> bytes:
35+
"""Encode an unsigned 32-bit big-endian integer."""
36+
return value.to_bytes(4, "big", signed=False)
37+
38+
39+
def _u64(value: int) -> bytes:
40+
"""Encode an unsigned 64-bit big-endian integer."""
41+
return value.to_bytes(8, "big", signed=False)
42+
43+
44+
def _len_prefixed(payload: bytes) -> bytes:
45+
"""Encode a byte payload prefixed with its u64 length."""
46+
return _u64(len(payload)) + payload
47+
48+
1949
def serialize_item(obj: object) -> bytes:
20-
"""Serialize a supported multimodal hash leaf to bytes."""
50+
"""Serialize a supported multimodal hash leaf to bytes.
51+
52+
The encoding is canonical and self-describing: every value is
53+
`[1-byte type tag][typed metadata][length-prefixed payload]` with all
54+
multi-byte integers big-endian. This prevents cache-key hash collisions
55+
between distinct values that happen to share a raw byte payload (for
56+
example transposed image dimensions or reshaped arrays).
57+
"""
2158
if isinstance(obj, str):
22-
return obj.encode("utf-8")
59+
return _u8(0x01) + _len_prefixed(obj.encode("utf-8"))
2360
if isinstance(obj, bytes):
24-
return obj
25-
if isinstance(obj, (int, float)):
26-
return np.array(obj).tobytes()
61+
return _u8(0x02) + _len_prefixed(obj)
62+
# bool must be checked before int: bool is a subclass of int.
63+
if isinstance(obj, bool):
64+
return _u8(0x05) + _u8(1 if obj else 0)
65+
if isinstance(obj, int):
66+
nbytes = (obj.bit_length() + 8) // 8 # +1 sign bit, then ceil-divide.
67+
return _u8(0x03) + _u8(nbytes) + obj.to_bytes(nbytes, "big", signed=True)
68+
if isinstance(obj, float):
69+
return _u8(0x04) + struct.pack(">d", obj)
2770

2871
if isinstance(obj, Image.Image):
29-
return np.array(obj.convert("RGBA")).tobytes()
30-
if isinstance(obj, torch.Tensor):
31-
return obj.numpy().tobytes()
32-
if isinstance(obj, np.ndarray):
33-
return obj.tobytes()
72+
width, height = obj.size
73+
payload = np.array(obj.convert("RGBA")).tobytes()
74+
return (
75+
_u8(0x10)
76+
+ _len_prefixed(obj.mode.encode("utf-8"))
77+
+ _u32(width)
78+
+ _u32(height)
79+
+ _len_prefixed(payload)
80+
)
81+
if isinstance(obj, (torch.Tensor, np.ndarray)):
82+
# The container (torch.Tensor vs np.ndarray) is not part of the content
83+
# identity -- only dtype, shape, and raw bytes are. Normalize both to a
84+
# contiguous NumPy array so identical content hashes identically.
85+
if isinstance(obj, torch.Tensor):
86+
obj = obj.detach().cpu().contiguous().numpy()
87+
array = np.ascontiguousarray(obj)
88+
parts = [
89+
_u8(0x11),
90+
_len_prefixed(array.dtype.str.encode("utf-8")),
91+
_u8(array.ndim),
92+
]
93+
parts.extend(_u64(dim) for dim in array.shape)
94+
parts.append(_len_prefixed(array.tobytes()))
95+
return b"".join(parts)
3496
if isinstance(obj, (tuple, list)):
35-
container_tag = b"T" if isinstance(obj, tuple) else b"L"
36-
parts = [container_tag, len(obj).to_bytes(8, "big", signed=False)]
37-
for item in obj:
38-
payload = serialize_item(item)
39-
parts.append(len(payload).to_bytes(8, "big", signed=False))
40-
parts.append(payload)
97+
# Ordered sequence; the container (tuple vs list) is not part of the
98+
# content identity.
99+
parts = [_u8(0x20), _u64(len(obj))]
100+
parts.extend(serialize_item(item) for item in obj)
101+
return b"".join(parts)
102+
if isinstance(obj, dict):
103+
parts = [_u8(0x22), _u64(len(obj))]
104+
for key in sorted(obj):
105+
parts.append(serialize_item(key))
106+
parts.append(serialize_item(obj[key]))
41107
return b"".join(parts)
42108

109+
if isinstance(obj, np.generic):
110+
# numpy scalar (e.g. np.int64 / np.float32 / np.bool_): normalize to the
111+
# equivalent Python scalar and recurse, so numpy-typed values hash
112+
# identically to their Python counterparts. In numpy 2.x these are not
113+
# subclasses of Python int/float/bool, so they bypass the checks above.
114+
return serialize_item(obj.item())
115+
43116
raise ValueError(f"Unsupported object type: {type(obj)}")
44117

45118

@@ -65,11 +138,8 @@ def __post_init__(self) -> None:
65138
self.sample_rate = int(self.sample_rate)
66139

67140
def update_hash(self, hasher: ContentHasher) -> None:
68-
samples = self.samples
69-
if isinstance(samples, torch.Tensor):
70-
samples = samples.detach().cpu().contiguous()
71141
hasher.update(b"<audio>")
72-
hasher.update(serialize_item((samples, self.sample_rate)))
142+
hasher.update(serialize_item((self.samples, self.sample_rate)))
73143

74144

75145
@dataclass
@@ -97,12 +167,12 @@ def __post_init__(self) -> None:
97167
raise TypeError("metadata must be a dictionary")
98168

99169
def update_hash(self, hasher: ContentHasher) -> None:
170+
hasher.update(b"<video>")
171+
# Sampling metadata is part of the model-visible cache identity.
172+
meta = {k: self.metadata[k] for k in _VIDEO_HASH_METADATA_FIELDS if k in self.metadata}
173+
hasher.update(serialize_item(meta))
100174
for frame in self.frames:
101175
hasher.update(b"<frame>")
102-
if isinstance(frame, torch.Tensor):
103-
frame = frame.detach().cpu().contiguous()
104176
hasher.update(serialize_item(frame))
105-
# Extend this to include metadata if fields such as sampled frame
106-
# indices become part of the model-visible cache identity.
107177
if self.audio is not None:
108178
self.audio.update_hash(hasher)

0 commit comments

Comments
 (0)