2525from merlin .schema import ColumnSchema , Schema
2626from merlin .systems .dag .ensemble import Ensemble
2727from merlin .systems .dag .ops .implicit import PredictImplicit
28- from merlin .systems .dag .runtimes .triton import TritonExecutorRuntime
2928from merlin .systems .triton .utils import run_triton_server
3029
3130TRITON_SERVER_PATH = shutil .which ("tritonserver" )
3534
3635
3736@pytest .mark .skipif (not TRITON_SERVER_PATH , reason = "triton server not found" )
38- @pytest .mark .parametrize ("runtime" , [None , TritonExecutorRuntime ()])
3937@pytest .mark .parametrize (
4038 "model_cls" ,
4139 [
4442 implicit .lmf .LogisticMatrixFactorization ,
4543 ],
4644)
47- def test_ensemble (model_cls , runtime , tmpdir ):
45+ def test_ensemble (model_cls , tmpdir ):
4846 model = model_cls ()
4947 n = 100
5048 user_items = csr_matrix (np .random .choice ([0 , 1 ], size = n * n , p = [0.9 , 0.1 ]).reshape (n , n ))
@@ -64,7 +62,7 @@ def test_ensemble(model_cls, runtime, tmpdir):
6462 triton_chain = input_schema .column_names >> implicit_op
6563
6664 triton_ens = Ensemble (triton_chain , input_schema )
67- ensemble_config , _ = triton_ens .export (tmpdir , runtime = runtime )
65+ ensemble_config , _ = triton_ens .export (tmpdir )
6866
6967 input_user_id = np .array ([[0 ], [1 ]], dtype = np .int64 )
7068 inputs = [
0 commit comments