Skip to content

Commit 51703e6

Browse files
committed
Add check is_valid_upload_filename routine to SwitchModel
1 parent 5d3c22d commit 51703e6

2 files changed

Lines changed: 23 additions & 5 deletions

File tree

drucker/drucker_dashboard_servicer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,20 @@ def SwitchModel(self,
149149
) -> drucker_pb2.ModelResponse:
150150
""" Switch your ML model to run.
151151
"""
152+
if not self.is_valid_upload_filename(request.path):
153+
raise Exception(f'Error: Invalid model path specified -> {request.path}')
154+
152155
model_assignment = self.app.db.session.query(ModelAssignment).filter(ModelAssignment.service_name == self.app.config.SERVICE_NAME).one()
153156
model_assignment.model_path = request.path
154157
model_assignment.first_boot = False
155158
self.app.db.session.commit()
156-
model_path = self.app.get_model_path()
157159

158160
# :TODO: Use enum for SERVICE_INFRA
159161
if self.app.config.SERVICE_INFRA == "kubernetes":
160162
pass
161163
elif self.app.config.SERVICE_INFRA == "default":
162-
self.app.load_model(model_path)
164+
self.app.model_path = self.app.get_model_path()
165+
self.app.load_model()
163166

164167
return drucker_pb2.ModelResponse(status=1,
165168
message='Success: Switching model file.')
@@ -174,7 +177,7 @@ def EvaluateModel(self,
174177
first_req = next(request_iterator)
175178
save_path = first_req.data_path
176179
if not self.is_valid_upload_filename(save_path):
177-
raise Exception(f'Error: Invalid evaluation file specified -> {save_path}')
180+
raise Exception(f'Error: Invalid evaluation file path specified -> {save_path}')
178181

179182
test_data = b''.join([first_req.data] + [r.data for r in request_iterator])
180183
result, details = self.app.evaluate(test_data)

drucker/test/test_dashboard_servicer.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class DruckerWorkerServicerTest(unittest.TestCase):
1515
def test_ServiceInfo(self):
1616
servicer = DruckerDashboardServicer(logger=system_logger, app=app)
1717
request = drucker_pb2.ServiceInfoRequest()
18-
response = servicer.ServiceInfo(request=request, context=Mock())
18+
response = servicer.ServiceInfo(request, Mock())
1919
self.assertEqual(response.application_name, 'test')
2020
self.assertEqual(response.service_name, 'test-001')
2121
self.assertEqual(response.service_level, 'development')
@@ -73,7 +73,22 @@ def test_SwitchModel(self, mock_app):
7373
response = servicer.SwitchModel(request, Mock())
7474

7575
self.assertEqual(response.status, 1)
76-
mock_app.load_model.assert_called_once_with('test/my_path')
76+
mock_app.load_model.assert_called_once_with()
77+
78+
@patch('drucker.test.DummyApp')
79+
@patch('drucker.drucker_dashboard_servicer.Path')
80+
def test_InvalidSwitchModel(self, mock_path_class, mock_app):
81+
# mock setting
82+
mock_path_class.return_value = Mock()
83+
mock_path_class.return_value.name = 'my_path'
84+
mock_app.get_model_path.return_value = 'test/my_path'
85+
mock_app.config.SERVICE_INFRA = 'default'
86+
87+
servicer = DruckerDashboardServicer(logger=system_logger, app=mock_app)
88+
request = drucker_pb2.SwitchModelRequest(path='../../my_path')
89+
response = servicer.SwitchModel(request, Mock())
90+
91+
self.assertEqual(response.status, 0)
7792

7893
@patch("builtins.open", new_callable=mock_open)
7994
@patch('drucker.drucker_dashboard_servicer.pickle')

0 commit comments

Comments
 (0)