@@ -135,17 +135,17 @@ def infer(self, full_data, window_interval=None, window_step=None, **_):
135135
136136class InferenceManager :
137137 ACCELERATE_MODEL_ID = "sundial"
138- DEFAULT_DEVICE = "cpu"
139- # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
138+ # DEFAULT_DEVICE = "cpu"
139+ DEFAULT_DEVICE = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
140140 DEFAULT_POOL_SIZE = (
141141 0 # TODO: Remove these parameter by sampling model inference consumption
142142 )
143143 WAITING_INTERVAL_IN_MS = (
144144 AINodeDescriptor ().get_config ().get_ain_inference_batch_interval_in_ms ()
145145 ) # How often to check for requests in the result queue
146146
147- def __init__ (self , model_manager : ModelManager ):
148- self ._model_manager = model_manager
147+ def __init__ (self ):
148+ self ._model_manager = ModelManager ()
149149 self ._result_queue = mp .Queue ()
150150 self ._result_wrapper_map = {}
151151 self ._result_wrapper_lock = threading .RLock ()
@@ -165,14 +165,11 @@ def _init_inference_request_pool(self):
165165 """
166166 self ._request_pool_map [self .ACCELERATE_MODEL_ID ] = []
167167 for idx in range (self .DEFAULT_POOL_SIZE ):
168- sundial_model = self ._model_manager .load_model (
169- self .ACCELERATE_MODEL_ID , {}
170- ).to (self .DEFAULT_DEVICE )
171168 sundial_config = SundialConfig ()
172169 request_queue = mp .Queue ()
173170 request_pool = InferenceRequestPool (
174171 pool_id = idx ,
175- model = sundial_model ,
172+ model_id = self . ACCELERATE_MODEL_ID ,
176173 config = sundial_config ,
177174 request_queue = request_queue ,
178175 result_queue = self ._result_queue ,
@@ -223,7 +220,8 @@ def _run(
223220 data = full_data [1 ][0 ]
224221 if data .dtype .byteorder not in ("=" , "|" ):
225222 data = data .byteswap ().newbyteorder ()
226- inputs = torch .tensor (data ).unsqueeze (0 ).float ().to (self .DEFAULT_DEVICE )
223+ # the inputs should be on CPU before passing to the inference request
224+ inputs = torch .tensor (data ).unsqueeze (0 ).float ().to ("cpu" )
227225 infer_req = InferenceRequest (
228226 req_id = _generate_req_id (),
229227 inputs = inputs ,
0 commit comments