Skip to content

Commit b81ff8c

Browse files
Merge pull request #304 from SubstraFoundation/mpl-approaches-test
Test all mpl approaches.
2 parents c034875 + 8448117 commit b81ff8c

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

tests/ml_perf_end_to_end_tests.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import pandas as pd
1010

11+
from mplc import multi_partner_learning
1112
from mplc.corruption import Duplication
1213
from mplc.experiment import Experiment
1314
from mplc.scenario import Scenario
@@ -48,3 +49,20 @@ def test_titanic(self):
4849
assert np.min(titanic_scenario_1.mpl.history.score) > 0.65
4950
result = pd.read_csv(exp.experiment_path / 'results.csv')
5051
assert (result.groupby('scenario_index').mean().mpl_test_score > 0.65).all()
52+
53+
def test_all_mpl_approaches(self):
54+
"""
55+
Test all the mpl approaches
56+
"""
57+
58+
exp = Experiment()
59+
mpl_approaches = multi_partner_learning.MULTI_PARTNER_LEARNING_APPROACHES.copy()
60+
61+
for approach in mpl_approaches:
62+
exp.add_scenario(Scenario(2, [0.25, 0.75], epoch_count=2, minibatch_count=2, dataset_name='mnist',
63+
dataset_proportion=0.1, multi_partner_learning_approach=approach,
64+
gradient_updates_per_pass_count=3))
65+
exp.run()
66+
67+
df = exp.result
68+
assert len(df) == len(mpl_approaches)

0 commit comments

Comments
 (0)