@@ -83,11 +83,14 @@ def __init__(
8383 response_queue_addr : str ,
8484 device_id : int ,
8585 diffusion_args : "VisualGenArgs" ,
86+ req_hmac_key : Optional [bytes ] = None ,
87+ resp_hmac_key : Optional [bytes ] = None ,
8688 ):
8789 self .request_queue_addr = request_queue_addr
8890 self .response_queue_addr = response_queue_addr
8991 self .device_id = device_id
9092 self .diffusion_args = diffusion_args
93+ self .resp_hmac_key = resp_hmac_key
9194
9295 self .pipeline = None # initialized in _load_pipeline
9396 self .requests_ipc = None
@@ -99,10 +102,10 @@ def __init__(
99102 if self .rank == 0 :
100103 logger .info (f"Worker { device_id } : Connecting to request queue" )
101104 self .requests_ipc = ZeroMqQueue (
102- (request_queue_addr , None ),
105+ (request_queue_addr , req_hmac_key ),
103106 is_server = False ,
104107 socket_type = zmq .PULL ,
105- use_hmac_encryption = False ,
108+ use_hmac_encryption = True ,
106109 )
107110 self .sender_thread = threading .Thread (target = self ._sender_loop , daemon = True )
108111 self .sender_thread .start ()
@@ -113,10 +116,10 @@ def _sender_loop(self):
113116 """Background thread for sending responses."""
114117 logger .info (f"Worker { self .device_id } : Connecting to response queue" )
115118 responses_ipc = ZeroMqQueue (
116- (self .response_queue_addr , None ),
119+ (self .response_queue_addr , self . resp_hmac_key ),
117120 is_server = False ,
118121 socket_type = zmq .PUSH ,
119- use_hmac_encryption = False ,
122+ use_hmac_encryption = True ,
120123 )
121124
122125 while True :
@@ -214,6 +217,8 @@ def run_diffusion_worker(
214217 response_queue_addr : str ,
215218 diffusion_args : "VisualGenArgs" ,
216219 log_level : str = "info" ,
220+ req_hmac_key : Optional [bytes ] = None ,
221+ resp_hmac_key : Optional [bytes ] = None ,
217222):
218223 """Entry point for worker process."""
219224 try :
@@ -248,6 +253,8 @@ def run_diffusion_worker(
248253 response_queue_addr = response_queue_addr ,
249254 device_id = device_id ,
250255 diffusion_args = diffusion_args ,
256+ req_hmac_key = req_hmac_key ,
257+ resp_hmac_key = resp_hmac_key ,
251258 )
252259 executor .serve_forever ()
253260 if executor .pipeline is not None :
0 commit comments