11import logging
2+ import tempfile
23import unittest
3- from unittest import mock
44
55import numpy as np
66import pandas as pd
@@ -235,9 +235,7 @@ def test_verbose(self):
235235 self .assertIn ("f1-score " , log_output )
236236 self .assertIn ("F1 Score: " , log_output )
237237
238- @mock .patch ("dataprofiler.labelers.labeler_utils.classification_report" )
239- @mock .patch ("pandas.DataFrame" )
240- def test_save_conf_mat (self , mock_dataframe , mock_report ):
238+ def test_save_conf_mat (self ):
241239
242240 # ideally mock out the actual contents written to file, but
243241 # would be difficult to get this completely worked out.
@@ -248,28 +246,25 @@ def test_save_conf_mat(self, mock_dataframe, mock_report):
248246 [0 , 1 , 2 ],
249247 ]
250248 )
251- expected_row_col_names = dict (
252- columns = ["pred:PAD" , "pred:UNKNOWN" , "pred:OTHER" ],
253- index = ["true:PAD" , "true:UNKNOWN" , "true:OTHER" ],
254- )
255- mock_instance_df = mock .Mock (spec = pd .DataFrame )()
256- mock_dataframe .return_value = mock_instance_df
257-
258- # still omit bc confusion mat should include all despite omit
259- f1 , f1_report = labeler_utils .evaluate_accuracy (
260- self .y_pred ,
261- self .y_true ,
262- self .num_labels ,
263- self .reverse_label_mapping ,
264- omitted_labels = ["PAD" ],
265- verbose = False ,
266- confusion_matrix_file = "test.csv" ,
267- )
249+ expected_columns = ["pred:PAD" , "pred:UNKNOWN" , "pred:OTHER" ]
250+ expected_index = ["true:PAD" , "true:UNKNOWN" , "true:OTHER" ]
268251
269- self .assertTrue ((mock_dataframe .call_args [0 ][0 ] == expected_conf_mat ).all ())
270- self .assertDictEqual (expected_row_col_names , mock_dataframe .call_args [1 ])
252+ with tempfile .NamedTemporaryFile () as tmpFile :
253+ # still omit bc confusion mat should include all despite omit
254+ f1 , f1_report = labeler_utils .evaluate_accuracy (
255+ self .y_pred ,
256+ self .y_true ,
257+ self .num_labels ,
258+ self .reverse_label_mapping ,
259+ omitted_labels = ["PAD" ],
260+ verbose = False ,
261+ confusion_matrix_file = tmpFile .name ,
262+ )
271263
272- mock_instance_df .to_csv .assert_called ()
264+ df1 = pd .read_csv (tmpFile .name , index_col = 0 )
265+ self .assertListEqual (list (df1 .columns ), expected_columns )
266+ self .assertListEqual (list (df1 .index ), expected_index )
267+ np .testing .assert_array_equal (df1 .values , expected_conf_mat )
273268
274269
275270class TestTFFunctions (unittest .TestCase ):
0 commit comments