Skip to content

Commit d1021d6

Browse files
niklub and nik authored
Revert "feat: HUMSIG-37: Enhancing ML backend connection experience (#5546)" (#5569)
Revert "feat: HUMSIG-37: Enhancing ML backend connection experience (#5546)" This reverts commit b198728. Co-authored-by: nik <nik@heartex.net>
1 parent 44e7325 commit d1021d6

92 files changed

Lines changed: 2615 additions & 4032 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.pre-commit-dev.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,3 @@ repos:
99
rev: v0.9.1
1010
hooks:
1111
- id: blue
12-
args: [ --verbose ]

label_studio/core/all_urls.json

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -946,12 +946,6 @@
946946
"module": "ml.api.MLBackendTrainAPI",
947947
"name": "ml:api:ml-train",
948948
"decorators": ""
949-
},
950-
{
951-
"url": "/api/ml/<int:pk>/predict/test",
952-
"module": "ml.api.MLBackendPredictAPI",
953-
"name": "ml:api:ml-predict-test",
954-
"decorators": ""
955949
},
956950
{
957951
"url": "/api/ml/<int:pk>/interactive-annotating",

label_studio/data_manager/actions/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from core.permissions import AllPermissions
77
from core.redis import start_job_async_or_sync
88
from core.utils.common import load_func
9-
from data_manager.functions import retrieve_predictions
9+
from data_manager.functions import evaluate_predictions
1010
from django.conf import settings
1111
from projects.models import Project
1212
from tasks.functions import update_tasks_counters
@@ -24,7 +24,7 @@ def retrieve_tasks_predictions(project, queryset, **kwargs):
2424
:param project: project instance
2525
:param queryset: filtered tasks db queryset
2626
"""
27-
retrieve_predictions(queryset)
27+
evaluate_predictions(queryset)
2828
return {'processed_items': queryset.count(), 'detail': 'Retrieved ' + str(queryset.count()) + ' predictions'}
2929

3030

@@ -138,7 +138,6 @@ def async_project_summary_recalculation(tasks_ids_list, project_id):
138138
'title': 'Retrieve Predictions',
139139
'order': 90,
140140
'dialog': {
141-
'modal_title': 'Retrieve Predictions',
142141
'text': 'Send the selected tasks to all ML backends connected to the project.'
143142
'This operation might be abruptly interrupted due to a timeout. '
144143
'The recommended way to get predictions is to update tasks using the Label Studio API.'

label_studio/data_manager/actions/next_task.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def next_task(project, queryset, **kwargs):
3232
# serialize task
3333
context = {'request': request, 'project': project, 'resolve_uri': True, 'annotations': False}
3434
serializer = NextTaskSerializer(next_task, context=context)
35-
3635
response = serializer.data
3736
response['queue'] = queue_info
3837
return response

label_studio/data_manager/actions/predictions_to_annotations.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def predictions_to_annotations_form(user, project):
8080
{
8181
'type': 'select',
8282
'name': 'model_version',
83-
'label': 'Choose predictions',
83+
'label': 'Choose a model',
8484
'options': versions,
8585
}
8686
],
@@ -95,10 +95,8 @@ def predictions_to_annotations_form(user, project):
9595
'title': 'Create Annotations From Predictions',
9696
'order': 91,
9797
'dialog': {
98-
'modal_title': 'Create Annotations From Predictions',
99-
'text': 'Create annotations from predictions using selected predictions set '
100-
'for each selected task.'
101-
'Your account will be assigned as an owner to those annotations. ',
98+
'text': 'This action will create new annotations from predictions with the selected model version '
99+
'for each selected task.',
102100
'type': 'confirm',
103101
'form': predictions_to_annotations_form,
104102
},

label_studio/data_manager/api.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from core.utils.common import int_from_request, load_func
99
from core.utils.params import bool_from_request
1010
from data_manager.actions import get_all_actions, perform_action
11-
from data_manager.functions import get_prepare_params, get_prepared_queryset
11+
from data_manager.functions import evaluate_predictions, get_prepare_params, get_prepared_queryset
1212
from data_manager.managers import get_fields_for_evaluation
1313
from data_manager.models import View
1414
from data_manager.serializers import DataManagerTaskSerializer, ViewResetSerializer, ViewSerializer
@@ -255,15 +255,11 @@ def get(self, request):
255255
# keep ids ordering
256256
page = [tasks_by_ids[_id] for _id in ids]
257257

258-
# TODO MM TODO this needs a discussion, because I'd expect
259-
# people to retrieve manually instead on DM load, plus it
260-
# will slow down initial DM load
261-
262258
# retrieve ML predictions if tasks don't have them
263-
# if not review and project.retrieve_predictions_automatically:
264-
# tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
265-
# retrieve_predictions(tasks_for_predictions)
266-
# [tasks_by_ids[_id].refresh_from_db() for _id in ids]
259+
if not review and project.evaluate_predictions_automatically:
260+
tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
261+
evaluate_predictions(tasks_for_predictions)
262+
[tasks_by_ids[_id].refresh_from_db() for _id in ids]
267263

268264
if flag_set('fflag_fix_back_leap_24_tasks_api_optimization_05092023_short'):
269265
serializer = self.task_serializer_class(
@@ -275,18 +271,13 @@ def get(self, request):
275271
else:
276272
serializer = self.task_serializer_class(page, many=True, context=context)
277273
return self.get_paginated_response(serializer.data)
278-
279-
# TODO
280274
# all tasks
281-
# if project.retrieve_predictions_automatically:
282-
# retrieve_predictions(queryset.filter(predictions__isnull=True))
283-
275+
if project.evaluate_predictions_automatically:
276+
evaluate_predictions(queryset.filter(predictions__isnull=True))
284277
queryset = Task.prepared.annotate_queryset(
285278
queryset, fields_for_evaluation=fields_for_evaluation, all_fields=all_fields, request=request
286279
)
287-
288280
serializer = self.task_serializer_class(queryset, many=True, context=context)
289-
290281
return Response(serializer.data)
291282

292283

@@ -386,7 +377,6 @@ def post(self, request):
386377
# perform action and return the result dict
387378
kwargs = {'request': request} # pass advanced params to actions
388379
result = perform_action(action_id, project, queryset, request.user, **kwargs)
389-
390380
code = result.pop('response_code', 200)
391381

392382
return Response(result, status=code)

label_studio/data_manager/functions.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -314,23 +314,16 @@ def get_prepared_queryset(request, project):
314314
return queryset
315315

316316

317-
def retrieve_predictions(tasks, backend=None):
318-
"""Call ML backend to retrieve predictions with the task queryset as an input"""
317+
def evaluate_predictions(tasks):
318+
"""Call ML backend for prediction evaluation of the task queryset"""
319319
if not tasks:
320320
return
321321

322-
if not backend:
323-
project = tasks[0].project
324-
backend = project.ml_backends.first()
322+
project = tasks[0].project
325323

326-
# IMPORTANT change here, ml_backends.all => ml_backends.first
327-
# we are using only one ML backend, not multiple
328-
if backend:
329-
return backend.predict_and_save(tasks=tasks)
330-
331-
# for ml_backend in project.ml_backends.first():
332-
# # tasks = tasks.filter(~Q(predictions__model_version=ml_backend.model_version))
333-
# ml_backend.predict_and_save(tasks=tasks)
324+
for ml_backend in project.ml_backends.all():
325+
# tasks = tasks.filter(~Q(predictions__model_version=ml_backend.model_version))
326+
ml_backend.predict_tasks(tasks)
334327

335328

336329
def filters_ordering_selected_items_exist(data):

label_studio/ml/api.py

Lines changed: 15 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from core.feature_flags import flag_set
77
from core.permissions import ViewClassPermission, all_permissions
88
from django.conf import settings
9-
from django.http import Http404
109
from django.utils.decorators import method_decorator
1110
from django_filters.rest_framework import DjangoFilterBackend
1211
from drf_yasg.utils import swagger_auto_schema
@@ -77,11 +76,10 @@ class MLBackendListAPI(generics.ListCreateAPIView):
7776
def get_queryset(self):
7877
project_pk = self.request.query_params.get('project')
7978
project = generics.get_object_or_404(Project, pk=project_pk)
80-
8179
self.check_object_permissions(self.request, project)
82-
83-
ml_backends = project.update_ml_backends_state()
84-
80+
ml_backends = MLBackend.objects.filter(project_id=project.id)
81+
for mlb in ml_backends:
82+
mlb.update_state()
8583
return ml_backends
8684

8785
def perform_create(self, serializer):
@@ -204,58 +202,6 @@ def post(self, request, *args, **kwargs):
204202
return Response(status=status.HTTP_200_OK)
205203

206204

207-
@method_decorator(
208-
name='post',
209-
decorator=swagger_auto_schema(
210-
tags=['Machine Learning'],
211-
operation_summary='Predict',
212-
operation_description="""
213-
After you add an ML backend, call this API with the ML backend ID to run a test prediction on specific task data
214-
""",
215-
manual_parameters=[
216-
openapi.Parameter(
217-
name='id',
218-
type=openapi.TYPE_INTEGER,
219-
in_=openapi.IN_PATH,
220-
description='A unique integer value identifying this ML backend.',
221-
),
222-
],
223-
responses={
224-
200: openapi.Response(title='Predicting OK', description='Predicting has successfully started.'),
225-
500: openapi.Response(
226-
description='Predicting error',
227-
schema=openapi.Schema(
228-
title='Error message',
229-
description='Error message',
230-
type=openapi.TYPE_STRING,
231-
example='Server responded with an error.',
232-
),
233-
),
234-
},
235-
),
236-
)
237-
class MLBackendPredictTestAPI(APIView):
238-
serializer_class = MLBackendSerializer
239-
permission_required = all_permissions.projects_change
240-
241-
def post(self, request, *args, **kwargs):
242-
ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
243-
self.check_object_permissions(self.request, ml_backend)
244-
245-
random = request.query_params.get('random', False)
246-
if random:
247-
task = Task.get_random(project=ml_backend.project)
248-
if not task:
249-
raise Http404
250-
251-
kwargs = ml_backend._predict(task)
252-
return Response(**kwargs)
253-
254-
# TODO this needs to be implemented and needs to have a specific task param
255-
ml_backend.predict()
256-
return Response(status=status.HTTP_200_OK)
257-
258-
259205
@method_decorator(
260206
name='post',
261207
decorator=swagger_auto_schema(
@@ -281,64 +227,29 @@ def post(self, request, *args, **kwargs):
281227
),
282228
)
283229
class MLBackendInteractiveAnnotating(APIView):
284-
""" """
285230

286231
permission_required = all_permissions.tasks_view
287232

288-
def _error_response(self, message, log_function=logger.info):
289-
""" """
290-
log_function(message)
291-
return Response({'errors': [message]}, status=status.HTTP_200_OK)
292-
293-
def _get_task(self, ml_backend, validated_data):
294-
""" """
295-
return generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)
296-
297-
def _get_credentials(self, request, context, project):
298-
""" """
299-
if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
300-
context.update(
301-
project_credentials_login=project.task_data_login,
302-
project_credentials_password=project.task_data_password,
303-
)
304-
return context
305-
306-
def _get_ml_results(self, ml_api_result):
307-
""" """
308-
results = ml_api_result.response.get('results', [None])
309-
if isinstance(results, list) and len(results) >= 1:
310-
return results[0]
311-
312-
return None
313-
314233
def post(self, request, *args, **kwargs):
315-
""" """
316234
ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
317-
self.check_object_permissions(request, ml_backend)
318-
235+
self.check_object_permissions(self.request, ml_backend)
319236
serializer = MLInteractiveAnnotatingRequest(data=request.data)
320237
serializer.is_valid(raise_exception=True)
238+
validated_data = serializer.validated_data
321239

322-
task = self._get_task(ml_backend, serializer.validated_data)
323-
context = self._get_credentials(request, serializer.validated_data.get('context', {}), task.project)
324-
325-
ml_api_result = ml_backend.interactive_annotating(task, context, user=self.request.user)
240+
task = generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)
241+
context = validated_data.get('context')
326242

327-
if ml_api_result.is_error:
328-
message = f'Prediction not created for project {self}: {ml_api_result.error_message}'
329-
return self._error_response(message)
330-
331-
if not isinstance(ml_api_result.response, dict) or 'results' not in ml_api_result.response:
332-
message = f'Incorrect response from ML service it must be a dict and contain "results" key: {ml_api_result.response}'
333-
return self._error_response(message)
334-
335-
ml_results = self._get_ml_results(ml_api_result)
243+
if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
244+
context['project_credentials_login'] = task.project.task_data_login
245+
context['project_credentials_password'] = task.project.task_data_password
336246

337-
if not ml_results:
338-
message = f'ML backend has to return a list with at least 1 annotation but it returned: {type(ml_results)}'
339-
return self._error_response(message, logger.warning)
247+
result = ml_backend.interactive_annotating(task, context, user=request.user)
340248

341-
return Response({'data': ml_results}, status=status.HTTP_200_OK)
249+
return Response(
250+
result,
251+
status=status.HTTP_200_OK,
252+
)
342253

343254

344255
@method_decorator(

0 commit comments

Comments
 (0)