Skip to content

Commit dff88de

Browse files
authored
Merge pull request #584 from VectorInstitute/test-report-module
Test report module
2 parents 11af403 + 278701d commit dff88de

1 file changed

Lines changed: 198 additions & 0 deletions

File tree

tests/cyclops/report/test_report.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from unittest import TestCase
44

55
from cyclops.report import ModelCardReport
6+
from cyclops.report.model_card.sections import ModelDetails
67

78

89
class TestModelCardReport(TestCase):
@@ -166,3 +167,200 @@ def test_log_quantitative_analysis_fairness(self):
166167
].value
167168
== 0.9
168169
)
170+
171+
def test_log_from_dict(self):
    """Check that ``log_from_dict`` fills in the overview section.

    A dictionary is logged against the ``"overview"`` section and the value
    stored under its ``"overview"`` key is then read back from the model card.
    """
    payload = {
        "datasets": "mnist",
        "Description": "dataset of digits from 0 to 9",
        "overview": "Handwritten 28x28 pixel image",
    }
    report = self.model_card_report
    report.log_from_dict(payload, "overview")

    # Read the section only after logging so we observe the logged state.
    logged_overview = report._model_card.overview.overview
    assert logged_overview == "Handwritten 28x28 pixel image"
183+
184+
def test_log_version(self):
    """Check that ``log_version`` stores the version string and description.

    The model details section is replaced with a fresh ``ModelDetails``
    instance before the version is logged.
    """
    report = self.model_card_report
    report._model_card.model_details = ModelDetails()
    report.log_version("1.2.0", description="Added new feature")

    # Fetch the version entry after logging, then inspect both fields.
    logged_version = report._model_card.model_details.version
    assert logged_version.version == "1.2.0"
    assert logged_version.description == "Added new feature"
196+
197+
def test_log_license(self):
    """Check that ``log_license`` adds a license with the given identifier."""
    report = self.model_card_report
    report.log_license("Apache-2.0")

    # The logged license is expected as the first entry of ``licenses``.
    first_license = report._model_card.model_details.licenses[0]
    assert first_license.identifier == "Apache-2.0"
204+
205+
def test_log_citation(self):
    """Check that ``log_citation`` stores the citation text unchanged.

    The BibTeX entry is logged and compared against the ``content`` of the
    first citation held in the model details section.
    """
    bibtex_entry = """@misc{vaswani2023attention,
title={Attention Is All You Need},
author={Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
year={2023},
eprint={1706.03762},
archivePrefix={arXiv},
primaryClass={cs.CL}
}"""
    report = self.model_card_report
    report.log_citation(bibtex_entry)

    first_citation = report._model_card.model_details.citations[0]
    assert first_citation.content == bibtex_entry
220+
221+
def test_log_reference(self):
    """Check that ``log_reference`` stores the link on the first reference."""
    url = "https://vectorinstitute.github.io/cyclops/api/reference/api/evaluator.html"
    report = self.model_card_report
    report.log_reference(url)

    first_reference = report._model_card.model_details.references[0]
    assert first_reference.link == url
230+
231+
def test_log_regulation(self):
    """Check that ``log_regulation`` records the regulation text.

    The text is compared against the ``regulation`` field of the first
    entry in ``model_details.regulatory_requirements``.
    """
    requirement_text = "sample regulation requirement"
    report = self.model_card_report
    report.log_regulation(requirement_text)

    requirements = report._model_card.model_details.regulatory_requirements
    assert requirements[0].regulation == requirement_text
241+
242+
def test_log_model_param(self):
    """Check that ``log_model_parameters`` exposes each parameter as an attribute."""
    weights_and_biases = {"w_1": [1.0, 0.49], "b_1": [0.32, 0.4]}
    report = self.model_card_report
    report.log_model_parameters(weights_and_biases)

    # Same check order as before: the bias first, then the weight.
    for key in ("b_1", "w_1"):
        logged_value = getattr(report._model_card.model_parameters, key)
        assert logged_value == weights_and_biases[key]
248+
249+
def test_log_dataset(self):
    """Check that ``log_dataset`` stores the description and citation.

    Both values are read back from the first entry of ``datasets.data``.
    """
    summary = "dataset of digits from 0 to 9"
    bibtex_entry = """@article{deng2012mnist,
title={The mnist database of handwritten digit images for machine learning research},
author={Deng, Li},
journal={IEEE Signal Processing Magazine},
volume={29},
number={6},
pages={141--142},
year={2012},
publisher={IEEE}
}"""
    report = self.model_card_report
    report.log_dataset(description=summary, citation=bibtex_entry)

    assert report._model_card.datasets.data[0].description == summary
    assert report._model_card.datasets.data[0].citation.content == bibtex_entry
268+
269+
def test_log_use_case(self):
    """Check that ``log_use_case`` stores the description of a primary use case."""
    description = "Medical imaging and segmentaion"
    report = self.model_card_report
    report.log_use_case(description, kind="primary")

    first_use_case = report._model_card.considerations.use_cases[0]
    assert first_use_case.description == description
277+
278+
def test_log_risk(self):
    """Check that ``log_risk`` stores a risk together with its mitigation.

    Both values are read back from the first entry of
    ``considerations.ethical_considerations``.
    """
    risk_text = "Ethical Considerations #2"
    mitigation_text = "Mitigation strategy #2"
    report = self.model_card_report
    report.log_risk(risk_text, mitigation_text)

    considerations = report._model_card.considerations
    assert considerations.ethical_considerations[0].risk == risk_text
    assert (
        considerations.ethical_considerations[0].mitigation_strategy
        == mitigation_text
    )
296+
297+
def test_fairness_assessment(self):
    """Check that ``log_fairness_assessment`` stores all four fields.

    The affected group, benefit, harm and mitigation strategy are logged
    positionally and read back from the first entry of
    ``considerations.fairness_assessment``.
    """
    group = "Group #3"
    expected_benefit = "Benefit #2"
    expected_harm = "Harm #5"
    expected_mitigation = "Mitigation strategy #2"

    report = self.model_card_report
    report.log_fairness_assessment(
        group, expected_benefit, expected_harm, expected_mitigation
    )

    assessments = report._model_card.considerations.fairness_assessment
    assert assessments[0].affected_group == group
    assert assessments[0].benefits == expected_benefit
    assert assessments[0].harms == expected_harm
    assert assessments[0].mitigation_strategy == expected_mitigation
332+
333+
def test_export(self):
    """Test exporting the model card report.

    A fairness assessment, two performance metrics and an owner are logged,
    then the report is exported with ``interactive=False`` and
    ``save_json=False``. The returned value (presumably the path of the
    exported report) must be a string.
    """
    report = self.model_card_report
    report.log_fairness_assessment(
        "Group #3", "Benefit #2", "Harm #5", "Mitigation strategy #2"
    )

    # Metric-specific keyword arguments; the shared ones are added in the loop.
    performance_metrics = (
        {
            "name": "BinaryAccuracy",
            "description": "Accuracy of the model on the test set",
            "value": 0.85,
            "decision_threshold": 0.7,
            "pass_fail_thresholds": [0.6, 0.65, 0.7],
            "pass_fail_threshold_fns": [lambda x, t: x >= t for _ in range(3)],
        },
        {
            "name": "BinaryF1Score",
            "value": 0.65,
            "decision_threshold": 0.8,
            "description": "F1 score of the model on the test set",
            "pass_fail_thresholds": [0.9, 0.85, 0.8],
            "pass_fail_threshold_fns": [lambda x, t: x >= t for _ in range(3)],
        },
    )
    for metric_kwargs in performance_metrics:
        report.log_quantitative_analysis(
            analysis_type="performance",
            metric_slice="overall",
            **metric_kwargs,
        )
    report.log_owner(name="John Doe")

    report_path = report.export(interactive=False, save_json=False)
    assert isinstance(report_path, str)

0 commit comments

Comments
 (0)