@@ -107,11 +107,11 @@ def check_model(self, model, database, mock_post):
107107 assert call_args [1 ]['json' ]['params' ] == params
108108
109109 # check prediction
110- assert ( pred_df == pd .DataFrame (data_out )). all (). bool ( )
110+ assert pred_df . equals ( pd .DataFrame (data_out ))
111111
112112 # predict using dict
113- pred_df = model .predict ({ 'a' : 1 })
114- assert ( pred_df == pd .DataFrame (data_out )). all (). bool ( )
113+ pred_df = model .predict ({'a' : 1 })
114+ assert pred_df . equals ( pd .DataFrame (data_out ))
115115
116116 # using deferred query
117117 response_mock (mock_post , pd .DataFrame (data_out )) # will be used sql/query
@@ -121,15 +121,15 @@ def check_model(self, model, database, mock_post):
121121
122122 check_sql_call (mock_post ,
123123 f'SELECT m.* FROM (SELECT * FROM { query .database } (select a from t1)) AS t JOIN { model .project .name } .{ model_name } AS m USING x="1"' )
124- assert ( pred_df == pd .DataFrame (data_out )). all (). bool ( )
124+ assert pred_df . equals ( pd .DataFrame (data_out ))
125125
126126 # using table
127127 table0 = database .tables .tbl0
128128 pred_df = model .predict (table0 )
129129
130130 check_sql_call (mock_post ,
131131 f'SELECT m.* FROM (SELECT * FROM { table0 .db .name } .tbl0) AS t JOIN { model .project .name } .{ model_name } AS m' )
132- assert ( pred_df == pd .DataFrame (data_out )). all (). bool ( )
132+ assert pred_df . equals ( pd .DataFrame (data_out ))
133133
134134
135135 # time series prediction
@@ -138,7 +138,7 @@ def check_model(self, model, database, mock_post):
138138
139139 check_sql_call (mock_post ,
140140 f'SELECT m.* FROM (SELECT * FROM { query .database } (select * from t1 where type="house" and saledate>latest)) as t JOIN { model .project .name } .{ model_name } AS m' )
141- assert ( pred_df == pd .DataFrame (data_out )). all (). bool ( )
141+ assert pred_df . equals ( pd .DataFrame (data_out ))
142142
143143 # ----------- model managing --------------
144144 response_mock (
@@ -524,7 +524,7 @@ def check_database(self, database, mock_post):
524524
525525 check_sql_call (mock_post , sql )
526526
527- assert ( data == result ). all (). bool ( )
527+ assert data . equals ( result )
528528
529529 # test tables
530530 response_mock (mock_post , pd .DataFrame ([{'name' : 't1' }]))
@@ -990,7 +990,7 @@ def check_database(self, database, mock_post):
990990
991991 check_sql_call (mock_post , sql )
992992
993- assert ( data == result ). all (). bool ( )
993+ assert data . equals ( result )
994994
995995 # test tables
996996 response_mock (mock_post , pd .DataFrame ([{'name' : 't1' }]))
0 commit comments