Skip to content

Commit efbe2ad

Browse files
committed
fixing tests
1 parent 4f4980c commit efbe2ad

2 files changed

Lines changed: 9 additions & 8 deletions

File tree

mindsdb_sdk/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def describe(self, type: str = None) -> Union[pd.DataFrame, Query]:
341341
identifier = self._get_identifier()
342342
if type is not None:
343343
identifier.parts.append(type)
344+
identifier.is_quoted.append(False)
344345
ast_query = Describe(identifier)
345346

346347
sql = ast_query.to_string()

tests/test_sdk.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ def check_model(self, model, database, mock_post):
107107
assert call_args[1]['json']['params'] == params
108108

109109
# check prediction
110-
assert (pred_df == pd.DataFrame(data_out)).all().bool()
110+
assert pred_df.equals(pd.DataFrame(data_out))
111111

112112
# predict using dict
113-
pred_df = model.predict({ 'a': 1 })
114-
assert (pred_df == pd.DataFrame(data_out)).all().bool()
113+
pred_df = model.predict({'a': 1})
114+
assert pred_df.equals(pd.DataFrame(data_out))
115115

116116
# using deferred query
117117
response_mock(mock_post, pd.DataFrame(data_out)) # will be used sql/query
@@ -121,15 +121,15 @@ def check_model(self, model, database, mock_post):
121121

122122
check_sql_call(mock_post,
123123
f'SELECT m.* FROM (SELECT * FROM {query.database} (select a from t1)) AS t JOIN {model.project.name}.{model_name} AS m USING x="1"')
124-
assert (pred_df == pd.DataFrame(data_out)).all().bool()
124+
assert pred_df.equals(pd.DataFrame(data_out))
125125

126126
# using table
127127
table0 = database.tables.tbl0
128128
pred_df = model.predict(table0)
129129

130130
check_sql_call(mock_post,
131131
f'SELECT m.* FROM (SELECT * FROM {table0.db.name}.tbl0) AS t JOIN {model.project.name}.{model_name} AS m')
132-
assert (pred_df == pd.DataFrame(data_out)).all().bool()
132+
assert pred_df.equals(pd.DataFrame(data_out))
133133

134134

135135
# time series prediction
@@ -138,7 +138,7 @@ def check_model(self, model, database, mock_post):
138138

139139
check_sql_call(mock_post,
140140
f'SELECT m.* FROM (SELECT * FROM {query.database} (select * from t1 where type="house" and saledate>latest)) as t JOIN {model.project.name}.{model_name} AS m')
141-
assert (pred_df == pd.DataFrame(data_out)).all().bool()
141+
assert pred_df.equals(pd.DataFrame(data_out))
142142

143143
# ----------- model managing --------------
144144
response_mock(
@@ -524,7 +524,7 @@ def check_database(self, database, mock_post):
524524

525525
check_sql_call(mock_post, sql)
526526

527-
assert (data == result).all().bool()
527+
assert data.equals(result)
528528

529529
# test tables
530530
response_mock(mock_post, pd.DataFrame([{'name': 't1'}]))
@@ -990,7 +990,7 @@ def check_database(self, database, mock_post):
990990

991991
check_sql_call(mock_post, sql)
992992

993-
assert (data == result).all().bool()
993+
assert data.equals(result)
994994

995995
# test tables
996996
response_mock(mock_post, pd.DataFrame([{'name': 't1'}]))

0 commit comments

Comments
 (0)