Skip to content

Commit c4d9577

Browse files
yibinl-nvidia authored and suyoggupta committed
[https://nvbugs/5922880][fix] Enable HMAC authentication in VisualGen ZMQ IPC channels (NVIDIA#12680)
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent cedaf8d commit c4d9577

2 files changed

Lines changed: 22 additions & 8 deletions

File tree

tensorrt_llm/_torch/visual_gen/executor.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -83,11 +83,14 @@ def __init__(
8383
response_queue_addr: str,
8484
device_id: int,
8585
diffusion_args: "VisualGenArgs",
86+
req_hmac_key: Optional[bytes] = None,
87+
resp_hmac_key: Optional[bytes] = None,
8688
):
8789
self.request_queue_addr = request_queue_addr
8890
self.response_queue_addr = response_queue_addr
8991
self.device_id = device_id
9092
self.diffusion_args = diffusion_args
93+
self.resp_hmac_key = resp_hmac_key
9194

9295
self.pipeline = None # initialized in _load_pipeline
9396
self.requests_ipc = None
@@ -99,10 +102,10 @@ def __init__(
99102
if self.rank == 0:
100103
logger.info(f"Worker {device_id}: Connecting to request queue")
101104
self.requests_ipc = ZeroMqQueue(
102-
(request_queue_addr, None),
105+
(request_queue_addr, req_hmac_key),
103106
is_server=False,
104107
socket_type=zmq.PULL,
105-
use_hmac_encryption=False,
108+
use_hmac_encryption=True,
106109
)
107110
self.sender_thread = threading.Thread(target=self._sender_loop, daemon=True)
108111
self.sender_thread.start()
@@ -113,10 +116,10 @@ def _sender_loop(self):
113116
"""Background thread for sending responses."""
114117
logger.info(f"Worker {self.device_id}: Connecting to response queue")
115118
responses_ipc = ZeroMqQueue(
116-
(self.response_queue_addr, None),
119+
(self.response_queue_addr, self.resp_hmac_key),
117120
is_server=False,
118121
socket_type=zmq.PUSH,
119-
use_hmac_encryption=False,
122+
use_hmac_encryption=True,
120123
)
121124

122125
while True:
@@ -214,6 +217,8 @@ def run_diffusion_worker(
214217
response_queue_addr: str,
215218
diffusion_args: "VisualGenArgs",
216219
log_level: str = "info",
220+
req_hmac_key: Optional[bytes] = None,
221+
resp_hmac_key: Optional[bytes] = None,
217222
):
218223
"""Entry point for worker process."""
219224
try:
@@ -248,6 +253,8 @@ def run_diffusion_worker(
248253
response_queue_addr=response_queue_addr,
249254
device_id=device_id,
250255
diffusion_args=diffusion_args,
256+
req_hmac_key=req_hmac_key,
257+
resp_hmac_key=resp_hmac_key,
251258
)
252259
executor.serve_forever()
253260
if executor.pipeline is not None:

tensorrt_llm/visual_gen/visual_gen.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
import asyncio
1616
import atexit
17+
import os
1718
import queue
1819
import socket
1920
import threading
@@ -86,6 +87,10 @@ def __init__(
8687
self.req_addr_connect = f"tcp://{self.host_ip}:{req_port}"
8788
self.resp_addr_connect = f"tcp://{self.host_ip}:{resp_port}"
8889

90+
# Generate shared HMAC keys for IPC authentication
91+
self.req_hmac_key = os.urandom(32)
92+
self.resp_hmac_key = os.urandom(32)
93+
8994
# IPC setup
9095
self.requests_ipc = None
9196
self.responses_ipc = None
@@ -122,6 +127,8 @@ def __init__(
122127
"request_queue_addr": self.req_addr_connect,
123128
"response_queue_addr": self.resp_addr_connect,
124129
"diffusion_args": self.diffusion_args,
130+
"req_hmac_key": self.req_hmac_key,
131+
"resp_hmac_key": self.resp_hmac_key,
125132
"log_level": logger.level,
126133
},
127134
)
@@ -205,16 +212,16 @@ def _init_ipc(self) -> bool:
205212
try:
206213
logger.info("DiffusionClient: Initializing IPC")
207214
self.requests_ipc = ZeroMqQueue(
208-
(self.request_queue_addr, None),
215+
(self.request_queue_addr, self.req_hmac_key),
209216
is_server=True,
210217
socket_type=zmq.PUSH,
211-
use_hmac_encryption=False,
218+
use_hmac_encryption=True,
212219
)
213220
self.responses_ipc = ZeroMqQueue(
214-
(self.response_queue_addr, None),
221+
(self.response_queue_addr, self.resp_hmac_key),
215222
is_server=True,
216223
socket_type=zmq.PULL,
217-
use_hmac_encryption=False,
224+
use_hmac_encryption=True,
218225
)
219226
logger.info("DiffusionClient: IPC ready")
220227
return True

0 commit comments

Comments
 (0)