Skip to content

Commit 14703f8

Browse files
authored
Merge pull request #17 from drucker/fix/check-invalid-filepath
Merged!
2 parents ada62e0 + 51703e6 commit 14703f8

2 files changed

Lines changed: 76 additions & 11 deletions

File tree

drucker/drucker_dashboard_servicer.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def on_error(self, error: Exception):
9999
self.logger.error(str(error))
100100
self.logger.error(traceback.format_exc())
101101

102+
def is_valid_upload_filename(self, filename: str) -> bool:
103+
if Path(filename).name == filename:
104+
return True
105+
return False
106+
102107
def ServiceInfo(self,
103108
request: drucker_pb2.ServiceInfoRequest,
104109
context: _Context
@@ -117,15 +122,20 @@ def UploadModel(self,
117122
) -> drucker_pb2.ModelResponse:
118123
""" Upload your latest ML model.
119124
"""
120-
save_path = None
125+
first_req = next(request_iterator)
126+
save_path = first_req.path
127+
if not self.is_valid_upload_filename(save_path):
128+
raise Exception(f'Error: Invalid model path specified -> {save_path}')
129+
121130
tmp_path = self.app.get_model_path(uuid.uuid4().hex)
122131
Path(tmp_path).parent.mkdir(parents=True, exist_ok=True)
123132
with open(tmp_path, 'wb') as f:
133+
f.write(first_req.data)
124134
for request in request_iterator:
125-
save_path = request.path
126-
model_data = request.data
127-
f.write(model_data)
135+
f.write(request.data)
136+
del first_req
128137
f.close()
138+
129139
model_path = self.app.get_model_path(save_path)
130140
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
131141
shutil.move(tmp_path, model_path)
@@ -139,17 +149,20 @@ def SwitchModel(self,
139149
) -> drucker_pb2.ModelResponse:
140150
""" Switch your ML model to run.
141151
"""
152+
if not self.is_valid_upload_filename(request.path):
153+
raise Exception(f'Error: Invalid model path specified -> {request.path}')
154+
142155
model_assignment = self.app.db.session.query(ModelAssignment).filter(ModelAssignment.service_name == self.app.config.SERVICE_NAME).one()
143156
model_assignment.model_path = request.path
144157
model_assignment.first_boot = False
145158
self.app.db.session.commit()
146-
model_path = self.app.get_model_path()
147159

148160
# :TODO: Use enum for SERVICE_INFRA
149161
if self.app.config.SERVICE_INFRA == "kubernetes":
150162
pass
151163
elif self.app.config.SERVICE_INFRA == "default":
152-
self.app.load_model(model_path)
164+
self.app.model_path = self.app.get_model_path()
165+
self.app.load_model()
153166

154167
return drucker_pb2.ModelResponse(status=1,
155168
message='Success: Switching model file.')
@@ -163,6 +176,8 @@ def EvaluateModel(self,
163176
"""
164177
first_req = next(request_iterator)
165178
save_path = first_req.data_path
179+
if not self.is_valid_upload_filename(save_path):
180+
raise Exception(f'Error: Invalid evaluation file path specified -> {save_path}')
166181

167182
test_data = b''.join([first_req.data] + [r.data for r in request_iterator])
168183
result, details = self.app.evaluate(test_data)

drucker/test/test_dashboard_servicer.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class DruckerWorkerServicerTest(unittest.TestCase):
1515
def test_ServiceInfo(self):
1616
servicer = DruckerDashboardServicer(logger=system_logger, app=app)
1717
request = drucker_pb2.ServiceInfoRequest()
18-
response = servicer.ServiceInfo(request=request, context=None)
18+
response = servicer.ServiceInfo(request, Mock())
1919
self.assertEqual(response.application_name, 'test')
2020
self.assertEqual(response.service_name, 'test-001')
2121
self.assertEqual(response.service_level, 'development')
@@ -27,12 +27,13 @@ def test_ServiceInfo(self):
2727
def test_UploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid):
2828
# mock setting
2929
mock_path_class.return_value = Mock()
30+
mock_path_class.return_value.name = 'my_path'
3031
mock_shutil.move.return_value = True
3132
mock_uuid.uuid4.return_value = Mock(hex='my_uuid')
3233

3334
servicer = DruckerDashboardServicer(logger=system_logger, app=app)
3435
requests = iter(drucker_pb2.UploadModelRequest(path='my_path', data=b'data') for _ in range(1, 3))
35-
response = servicer.UploadModel(request_iterator=requests, context=None)
36+
response = servicer.UploadModel(requests, Mock())
3637

3738
tmp_path = './test-model/test/my_uuid'
3839
save_path = './test-model/test/my_path'
@@ -44,6 +45,23 @@ def test_UploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid):
4445
], any_order=True)
4546
mock_shutil.move.assert_called_once_with(tmp_path, save_path)
4647

48+
@patch('drucker.drucker_dashboard_servicer.uuid')
49+
@patch('drucker.drucker_dashboard_servicer.shutil')
50+
@patch('drucker.drucker_dashboard_servicer.Path')
51+
@patch("builtins.open", new_callable=mock_open)
52+
def test_InvalidUploadModel(self, mock_file, mock_path_class, mock_shutil, mock_uuid):
53+
# mock setting
54+
mock_path_class.return_value = Mock()
55+
mock_path_class.return_value.name = 'my_path'
56+
mock_shutil.move.return_value = True
57+
mock_uuid.uuid4.return_value = Mock(hex='my_uuid')
58+
59+
servicer = DruckerDashboardServicer(logger=system_logger, app=app)
60+
requests = iter(drucker_pb2.UploadModelRequest(path='../../../my_path', data=b'data') for _ in range(1, 3))
61+
response = servicer.UploadModel(requests, Mock())
62+
63+
self.assertEqual(response.status, 0)
64+
4765
@patch('drucker.test.DummyApp')
4866
def test_SwitchModel(self, mock_app):
4967
# mock setting
@@ -52,10 +70,25 @@ def test_SwitchModel(self, mock_app):
5270

5371
servicer = DruckerDashboardServicer(logger=system_logger, app=mock_app)
5472
request = drucker_pb2.SwitchModelRequest(path='my_path')
55-
response = servicer.SwitchModel(request=request, context=None)
73+
response = servicer.SwitchModel(request, Mock())
5674

5775
self.assertEqual(response.status, 1)
58-
mock_app.load_model.assert_called_once_with('test/my_path')
76+
mock_app.load_model.assert_called_once_with()
77+
78+
@patch('drucker.test.DummyApp')
79+
@patch('drucker.drucker_dashboard_servicer.Path')
80+
def test_InvalidSwitchModel(self, mock_path_class, mock_app):
81+
# mock setting
82+
mock_path_class.return_value = Mock()
83+
mock_path_class.return_value.name = 'my_path'
84+
mock_app.get_model_path.return_value = 'test/my_path'
85+
mock_app.config.SERVICE_INFRA = 'default'
86+
87+
servicer = DruckerDashboardServicer(logger=system_logger, app=mock_app)
88+
request = drucker_pb2.SwitchModelRequest(path='../../my_path')
89+
response = servicer.SwitchModel(request, Mock())
90+
91+
self.assertEqual(response.status, 0)
5992

6093
@patch("builtins.open", new_callable=mock_open)
6194
@patch('drucker.drucker_dashboard_servicer.pickle')
@@ -67,7 +100,7 @@ def test_EvalauteModel(self, mock_pickle, mock_file):
67100

68101
servicer = DruckerDashboardServicer(logger=system_logger, app=app)
69102
requests = iter(drucker_pb2.EvaluateModelRequest(data_path='my_path', data=b'data_') for _ in range(1, 3))
70-
response = servicer.EvaluateModel(request_iterator=requests, context=None)
103+
response = servicer.EvaluateModel(requests, Mock())
71104

72105
self.assertEqual(round(response.metrics.num, 3), eval_result.num)
73106
self.assertEqual(round(response.metrics.accuracy, 3), eval_result.accuracy)
@@ -83,6 +116,23 @@ def test_EvalauteModel(self, mock_pickle, mock_file):
83116
call("./eval/test/my_path_eval_detail.pkl", "wb")
84117
], any_order=True)
85118

119+
@patch("builtins.open", new_callable=mock_open)
120+
@patch('drucker.drucker_dashboard_servicer.pickle')
121+
@patch('drucker.drucker_dashboard_servicer.Path')
122+
def test_InvalidEvalauteModel(self, mock_path_class, mock_pickle, mock_file):
123+
# mock setting
124+
mock_path_class.return_value = Mock()
125+
mock_path_class.return_value.name = 'my_path'
126+
eval_result = EvaluateResult(1, 0.8, [0.7], [0.6], [0.5], {'dummy': 0.4})
127+
details = [EvaluateDetail('test_input', 'test_label', PredictResult('pre_label', 0.9), False)]
128+
app.evaluate = Mock(return_value=(eval_result, details))
129+
130+
servicer = DruckerDashboardServicer(logger=system_logger, app=app)
131+
requests = iter(drucker_pb2.EvaluateModelRequest(data_path='../../my_path', data=b'data_') for _ in range(1, 3))
132+
response = servicer.EvaluateModel(requests, Mock())
133+
134+
self.assertEqual(response.metrics.num, 0)
135+
86136
def test_error_handling(self):
87137
# mock setting
88138
app.get_model_path = Mock(side_effect=Exception('dummy exception'))

0 commit comments

Comments
 (0)