66from core .feature_flags import flag_set
77from core .permissions import ViewClassPermission , all_permissions
88from django .conf import settings
9+ from django .http import Http404
910from django .utils .decorators import method_decorator
1011from django_filters .rest_framework import DjangoFilterBackend
1112from drf_yasg .utils import swagger_auto_schema
@@ -76,10 +77,11 @@ class MLBackendListAPI(generics.ListCreateAPIView):
7677 def get_queryset (self ):
7778 project_pk = self .request .query_params .get ('project' )
7879 project = generics .get_object_or_404 (Project , pk = project_pk )
80+
7981 self .check_object_permissions (self .request , project )
80- ml_backends = MLBackend . objects . filter ( project_id = project . id )
81- for mlb in ml_backends :
82- mlb . update_state ()
82+
83+ ml_backends = project . update_ml_backends_state ()
84+
8385 return ml_backends
8486
8587 def perform_create (self , serializer ):
@@ -202,6 +204,58 @@ def post(self, request, *args, **kwargs):
202204 return Response (status = status .HTTP_200_OK )
203205
204206
207+ @method_decorator (
208+ name = 'post' ,
209+ decorator = swagger_auto_schema (
210+ tags = ['Machine Learning' ],
211+ operation_summary = 'Predict' ,
212+ operation_description = """
213+ After you add an ML backend, call this API with the ML backend ID to run a test prediction on specific task data
214+ """ ,
215+ manual_parameters = [
216+ openapi .Parameter (
217+ name = 'id' ,
218+ type = openapi .TYPE_INTEGER ,
219+ in_ = openapi .IN_PATH ,
220+ description = 'A unique integer value identifying this ML backend.' ,
221+ ),
222+ ],
223+ responses = {
224+ 200 : openapi .Response (title = 'Predicting OK' , description = 'Predicting has successfully started.' ),
225+ 500 : openapi .Response (
226+ description = 'Predicting error' ,
227+ schema = openapi .Schema (
228+ title = 'Error message' ,
229+ description = 'Error message' ,
230+ type = openapi .TYPE_STRING ,
231+ example = 'Server responded with an error.' ,
232+ ),
233+ ),
234+ },
235+ ),
236+ )
237+ class MLBackendPredictTestAPI (APIView ):
238+ serializer_class = MLBackendSerializer
239+ permission_required = all_permissions .projects_change
240+
241+ def post (self , request , * args , ** kwargs ):
242+ ml_backend = generics .get_object_or_404 (MLBackend , pk = self .kwargs ['pk' ])
243+ self .check_object_permissions (self .request , ml_backend )
244+
245+ random = request .query_params .get ('random' , False )
246+ if random :
247+ task = Task .get_random (project = ml_backend .project )
248+ if not task :
249+ raise Http404
250+
251+ kwargs = ml_backend ._predict (task )
252+ return Response (** kwargs )
253+
254+ # TODO this needs to be implemented and needs to have a specific task param
255+ ml_backend .predict ()
256+ return Response (status = status .HTTP_200_OK )
257+
258+
205259@method_decorator (
206260 name = 'post' ,
207261 decorator = swagger_auto_schema (
@@ -227,29 +281,64 @@ def post(self, request, *args, **kwargs):
227281 ),
228282)
229283class MLBackendInteractiveAnnotating (APIView ):
284+ """ """
230285
231286 permission_required = all_permissions .tasks_view
232287
288+ def _error_response (self , message , log_function = logger .info ):
289+ """ """
290+ log_function (message )
291+ return Response ({'errors' : [message ]}, status = status .HTTP_200_OK )
292+
293+ def _get_task (self , ml_backend , validated_data ):
294+ """ """
295+ return generics .get_object_or_404 (Task , pk = validated_data ['task' ], project = ml_backend .project )
296+
297+ def _get_credentials (self , request , context , project ):
298+ """ """
299+ if flag_set ('ff_back_dev_2362_project_credentials_060722_short' , request .user ):
300+ context .update (
301+ project_credentials_login = project .task_data_login ,
302+ project_credentials_password = project .task_data_password ,
303+ )
304+ return context
305+
306+ def _get_ml_results (self , ml_api_result ):
307+ """ """
308+ results = ml_api_result .response .get ('results' , [None ])
309+ if isinstance (results , list ) and len (results ) >= 1 :
310+ return results [0 ]
311+
312+ return None
313+
233314 def post (self , request , * args , ** kwargs ):
315+ """ """
234316 ml_backend = generics .get_object_or_404 (MLBackend , pk = self .kwargs ['pk' ])
235- self .check_object_permissions (self .request , ml_backend )
317+ self .check_object_permissions (request , ml_backend )
318+
236319 serializer = MLInteractiveAnnotatingRequest (data = request .data )
237320 serializer .is_valid (raise_exception = True )
238- validated_data = serializer .validated_data
239321
240- task = generics . get_object_or_404 ( Task , pk = validated_data [ 'task' ], project = ml_backend . project )
241- context = validated_data .get ('context' )
322+ task = self . _get_task ( ml_backend , serializer . validated_data )
323+ context = self . _get_credentials ( request , serializer . validated_data .get ('context' , {}), task . project )
242324
243- if flag_set ('ff_back_dev_2362_project_credentials_060722_short' , request .user ):
244- context ['project_credentials_login' ] = task .project .task_data_login
245- context ['project_credentials_password' ] = task .project .task_data_password
325+ ml_api_result = ml_backend .interactive_annotating (task , context , user = self .request .user )
326+
327+ if ml_api_result .is_error :
328+ message = f'Prediction not created for project { self } : { ml_api_result .error_message } '
329+ return self ._error_response (message )
330+
331+ if not isinstance (ml_api_result .response , dict ) or 'results' not in ml_api_result .response :
332+ message = f'Incorrect response from ML service it must be a dict and contain "results" key: { ml_api_result .response } '
333+ return self ._error_response (message )
334+
335+ ml_results = self ._get_ml_results (ml_api_result )
246336
247- result = ml_backend .interactive_annotating (task , context , user = request .user )
337+ if not ml_results :
338+ message = f'ML backend has to return a list with at least 1 annotation but it returned: { type (ml_results )} '
339+ return self ._error_response (message , logger .warning )
248340
249- return Response (
250- result ,
251- status = status .HTTP_200_OK ,
252- )
341+ return Response ({'data' : ml_results }, status = status .HTTP_200_OK )
253342
254343
255344@method_decorator (
0 commit comments