Skip to content

Commit 672c427

Browse files
committed
Enable loading onnx model from mlflow
1 parent 777048b commit 672c427

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

bats_ai/tasks/tasks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,7 @@ def _fully_local_inference(image_file, use_mlflow_model):
462462
'CPUExecutionProvider',
463463
],
464464
)
465+
model = onnx.load(onnx_filename)
465466
else:
466467
import mlflow
467468
import mlflow.onnx
@@ -516,7 +517,6 @@ def _fully_local_inference(image_file, use_mlflow_model):
516517
outputs = np.vstack(outputs)
517518
outputs = outputs.mean(axis=0)
518519

519-
model = onnx.load(onnx_filename)
520520
mapping = json.loads(model.metadata_props[0].value)
521521
labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))]
522522

@@ -536,6 +536,7 @@ def predict_compressed(image_file):
536536
inference_mode = int(os.getenv('INFERENCE_MODE', 0))
537537
if inference_mode == 1:
538538
print('Using inference mode 1: file from mlflow')
539+
return _fully_local_inference(image_file, True)
539540
elif inference_mode == 2:
540541
print('Using inference mode 2: deployed mlflow model')
541542
else:

0 commit comments

Comments
 (0)