Skip to content

Commit b198728

Browse files
niklubMikhail MaluykdepppMichael Malyuknik
authored
feat: HUMSIG-37: Enhancing ML backend connection experience (#5546)
* Updated implementation of the ML Backend experience. Docs & Test to follow * Updating the experience a bit more based on the feedback. Plus updating tests * minor updates on the experience * removing debug info, uncommenting bits * fixing small issue to make it backward compatible with the previous project setting * Fix errors, code cleanup, fix ruff * Downgrade testing-library, reformat with linters * Remove excessive calls in task api * Fix frontend script * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8192397383 * fmt * fix sdk version & change api /predict/test * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8201028879 * try update lock * [submodules] Copy src HumanSignal/dm2 Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8207748036 * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8207796735 * Address review comments with stylistic changes, remove unusable code, change some writings * Additional changes * Add svg icon file * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8217959028 * Fix ml/predict/test api * Running a formatter/lint on the code * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8237179655 * Running a formatter/lint on the code * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8237277894 * Running a formatter/lint on the code * Running a formatter/lint on the code * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8237418085 * Handle security on password input * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8238405264 * Display default password when not specified * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8240219216 * Reduce number of calls due to project id cache * ci: Build frontend Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/8240697575 --------- Co-authored-by: Mikhail Maluyk <mikhail.maluyk@gmail.com> Co-authored-by: Michael Malyuk <28912+deppp@users.noreply.github.com> Co-authored-by: Michael Malyuk <michaelmalyuk@Michaels-MacBook-Air-2.local> Co-authored-by: nik <nik@heartex.net> Co-authored-by: robot-ci-heartex <robot-ci-heartex@users.noreply.github.com> Co-authored-by: Jo Booth <jo.m.booth@gmail.com> Co-authored-by: Brandon Martel <brandonmartel@gmail.com>
1 parent 537d30d commit b198728

92 files changed

Lines changed: 4032 additions & 2615 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.pre-commit-dev.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ repos:
99
rev: v0.9.1
1010
hooks:
1111
- id: blue
12+
args: [ --verbose ]

label_studio/core/all_urls.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,12 @@
946946
"module": "ml.api.MLBackendTrainAPI",
947947
"name": "ml:api:ml-train",
948948
"decorators": ""
949+
},
950+
{
951+
"url": "/api/ml/<int:pk>/predict/test",
952+
"module": "ml.api.MLBackendPredictAPI",
953+
"name": "ml:api:ml-predict-test",
954+
"decorators": ""
949955
},
950956
{
951957
"url": "/api/ml/<int:pk>/interactive-annotating",

label_studio/data_manager/actions/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from core.permissions import AllPermissions
77
from core.redis import start_job_async_or_sync
88
from core.utils.common import load_func
9-
from data_manager.functions import evaluate_predictions
9+
from data_manager.functions import retrieve_predictions
1010
from django.conf import settings
1111
from projects.models import Project
1212
from tasks.functions import update_tasks_counters
@@ -24,7 +24,7 @@ def retrieve_tasks_predictions(project, queryset, **kwargs):
2424
:param project: project instance
2525
:param queryset: filtered tasks db queryset
2626
"""
27-
evaluate_predictions(queryset)
27+
retrieve_predictions(queryset)
2828
return {'processed_items': queryset.count(), 'detail': 'Retrieved ' + str(queryset.count()) + ' predictions'}
2929

3030

@@ -138,6 +138,7 @@ def async_project_summary_recalculation(tasks_ids_list, project_id):
138138
'title': 'Retrieve Predictions',
139139
'order': 90,
140140
'dialog': {
141+
'modal_title': 'Retrieve Predictions',
141142
'text': 'Send the selected tasks to all ML backends connected to the project.'
142143
'This operation might be abruptly interrupted due to a timeout. '
143144
'The recommended way to get predictions is to update tasks using the Label Studio API.'

label_studio/data_manager/actions/next_task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def next_task(project, queryset, **kwargs):
3232
# serialize task
3333
context = {'request': request, 'project': project, 'resolve_uri': True, 'annotations': False}
3434
serializer = NextTaskSerializer(next_task, context=context)
35+
3536
response = serializer.data
3637
response['queue'] = queue_info
3738
return response

label_studio/data_manager/actions/predictions_to_annotations.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def predictions_to_annotations_form(user, project):
8080
{
8181
'type': 'select',
8282
'name': 'model_version',
83-
'label': 'Choose a model',
83+
'label': 'Choose predictions',
8484
'options': versions,
8585
}
8686
],
@@ -95,8 +95,10 @@ def predictions_to_annotations_form(user, project):
9595
'title': 'Create Annotations From Predictions',
9696
'order': 91,
9797
'dialog': {
98-
'text': 'This action will create new annotations from predictions with the selected model version '
99-
'for each selected task.',
98+
'modal_title': 'Create Annotations From Predictions',
99+
'text': 'Create annotations from predictions using selected predictions set '
100+
'for each selected task.'
101+
'Your account will be assigned as an owner to those annotations. ',
100102
'type': 'confirm',
101103
'form': predictions_to_annotations_form,
102104
},

label_studio/data_manager/api.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from core.utils.common import int_from_request, load_func
99
from core.utils.params import bool_from_request
1010
from data_manager.actions import get_all_actions, perform_action
11-
from data_manager.functions import evaluate_predictions, get_prepare_params, get_prepared_queryset
11+
from data_manager.functions import get_prepare_params, get_prepared_queryset
1212
from data_manager.managers import get_fields_for_evaluation
1313
from data_manager.models import View
1414
from data_manager.serializers import DataManagerTaskSerializer, ViewResetSerializer, ViewSerializer
@@ -255,11 +255,15 @@ def get(self, request):
255255
# keep ids ordering
256256
page = [tasks_by_ids[_id] for _id in ids]
257257

258+
# TODO MM TODO this needs a discussion, because I'd expect
259+
# people to retrieve manually instead on DM load, plus it
260+
# will slow down initial DM load
261+
258262
# retrieve ML predictions if tasks don't have them
259-
if not review and project.evaluate_predictions_automatically:
260-
tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
261-
evaluate_predictions(tasks_for_predictions)
262-
[tasks_by_ids[_id].refresh_from_db() for _id in ids]
263+
# if not review and project.retrieve_predictions_automatically:
264+
# tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
265+
# retrieve_predictions(tasks_for_predictions)
266+
# [tasks_by_ids[_id].refresh_from_db() for _id in ids]
263267

264268
if flag_set('fflag_fix_back_leap_24_tasks_api_optimization_05092023_short'):
265269
serializer = self.task_serializer_class(
@@ -271,13 +275,18 @@ def get(self, request):
271275
else:
272276
serializer = self.task_serializer_class(page, many=True, context=context)
273277
return self.get_paginated_response(serializer.data)
278+
279+
# TODO
274280
# all tasks
275-
if project.evaluate_predictions_automatically:
276-
evaluate_predictions(queryset.filter(predictions__isnull=True))
281+
# if project.retrieve_predictions_automatically:
282+
# retrieve_predictions(queryset.filter(predictions__isnull=True))
283+
277284
queryset = Task.prepared.annotate_queryset(
278285
queryset, fields_for_evaluation=fields_for_evaluation, all_fields=all_fields, request=request
279286
)
287+
280288
serializer = self.task_serializer_class(queryset, many=True, context=context)
289+
281290
return Response(serializer.data)
282291

283292

@@ -377,6 +386,7 @@ def post(self, request):
377386
# perform action and return the result dict
378387
kwargs = {'request': request} # pass advanced params to actions
379388
result = perform_action(action_id, project, queryset, request.user, **kwargs)
389+
380390
code = result.pop('response_code', 200)
381391

382392
return Response(result, status=code)

label_studio/data_manager/functions.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -314,16 +314,23 @@ def get_prepared_queryset(request, project):
314314
return queryset
315315

316316

317-
def evaluate_predictions(tasks):
318-
"""Call ML backend for prediction evaluation of the task queryset"""
317+
def retrieve_predictions(tasks, backend=None):
318+
"""Call ML backend to retrieve predictions with the task queryset as an input"""
319319
if not tasks:
320320
return
321321

322-
project = tasks[0].project
322+
if not backend:
323+
project = tasks[0].project
324+
backend = project.ml_backends.first()
323325

324-
for ml_backend in project.ml_backends.all():
325-
# tasks = tasks.filter(~Q(predictions__model_version=ml_backend.model_version))
326-
ml_backend.predict_tasks(tasks)
326+
# IMPORTANT change here, ml_backends.all => ml_backends.first
327+
# we are using only one ML backend, not multiple
328+
if backend:
329+
return backend.predict_and_save(tasks=tasks)
330+
331+
# for ml_backend in project.ml_backends.first():
332+
# # tasks = tasks.filter(~Q(predictions__model_version=ml_backend.model_version))
333+
# ml_backend.predict_and_save(tasks=tasks)
327334

328335

329336
def filters_ordering_selected_items_exist(data):

label_studio/ml/api.py

Lines changed: 104 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from core.feature_flags import flag_set
77
from core.permissions import ViewClassPermission, all_permissions
88
from django.conf import settings
9+
from django.http import Http404
910
from django.utils.decorators import method_decorator
1011
from django_filters.rest_framework import DjangoFilterBackend
1112
from drf_yasg.utils import swagger_auto_schema
@@ -76,10 +77,11 @@ class MLBackendListAPI(generics.ListCreateAPIView):
7677
def get_queryset(self):
7778
project_pk = self.request.query_params.get('project')
7879
project = generics.get_object_or_404(Project, pk=project_pk)
80+
7981
self.check_object_permissions(self.request, project)
80-
ml_backends = MLBackend.objects.filter(project_id=project.id)
81-
for mlb in ml_backends:
82-
mlb.update_state()
82+
83+
ml_backends = project.update_ml_backends_state()
84+
8385
return ml_backends
8486

8587
def perform_create(self, serializer):
@@ -202,6 +204,58 @@ def post(self, request, *args, **kwargs):
202204
return Response(status=status.HTTP_200_OK)
203205

204206

207+
@method_decorator(
208+
name='post',
209+
decorator=swagger_auto_schema(
210+
tags=['Machine Learning'],
211+
operation_summary='Predict',
212+
operation_description="""
213+
After you add an ML backend, call this API with the ML backend ID to run a test prediction on specific task data
214+
""",
215+
manual_parameters=[
216+
openapi.Parameter(
217+
name='id',
218+
type=openapi.TYPE_INTEGER,
219+
in_=openapi.IN_PATH,
220+
description='A unique integer value identifying this ML backend.',
221+
),
222+
],
223+
responses={
224+
200: openapi.Response(title='Predicting OK', description='Predicting has successfully started.'),
225+
500: openapi.Response(
226+
description='Predicting error',
227+
schema=openapi.Schema(
228+
title='Error message',
229+
description='Error message',
230+
type=openapi.TYPE_STRING,
231+
example='Server responded with an error.',
232+
),
233+
),
234+
},
235+
),
236+
)
237+
class MLBackendPredictTestAPI(APIView):
238+
serializer_class = MLBackendSerializer
239+
permission_required = all_permissions.projects_change
240+
241+
def post(self, request, *args, **kwargs):
242+
ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
243+
self.check_object_permissions(self.request, ml_backend)
244+
245+
random = request.query_params.get('random', False)
246+
if random:
247+
task = Task.get_random(project=ml_backend.project)
248+
if not task:
249+
raise Http404
250+
251+
kwargs = ml_backend._predict(task)
252+
return Response(**kwargs)
253+
254+
# TODO this needs to be implemented and needs to have a specific task param
255+
ml_backend.predict()
256+
return Response(status=status.HTTP_200_OK)
257+
258+
205259
@method_decorator(
206260
name='post',
207261
decorator=swagger_auto_schema(
@@ -227,29 +281,64 @@ def post(self, request, *args, **kwargs):
227281
),
228282
)
229283
class MLBackendInteractiveAnnotating(APIView):
284+
""" """
230285

231286
permission_required = all_permissions.tasks_view
232287

288+
def _error_response(self, message, log_function=logger.info):
289+
""" """
290+
log_function(message)
291+
return Response({'errors': [message]}, status=status.HTTP_200_OK)
292+
293+
def _get_task(self, ml_backend, validated_data):
294+
""" """
295+
return generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)
296+
297+
def _get_credentials(self, request, context, project):
298+
""" """
299+
if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
300+
context.update(
301+
project_credentials_login=project.task_data_login,
302+
project_credentials_password=project.task_data_password,
303+
)
304+
return context
305+
306+
def _get_ml_results(self, ml_api_result):
307+
""" """
308+
results = ml_api_result.response.get('results', [None])
309+
if isinstance(results, list) and len(results) >= 1:
310+
return results[0]
311+
312+
return None
313+
233314
def post(self, request, *args, **kwargs):
315+
""" """
234316
ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
235-
self.check_object_permissions(self.request, ml_backend)
317+
self.check_object_permissions(request, ml_backend)
318+
236319
serializer = MLInteractiveAnnotatingRequest(data=request.data)
237320
serializer.is_valid(raise_exception=True)
238-
validated_data = serializer.validated_data
239321

240-
task = generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)
241-
context = validated_data.get('context')
322+
task = self._get_task(ml_backend, serializer.validated_data)
323+
context = self._get_credentials(request, serializer.validated_data.get('context', {}), task.project)
242324

243-
if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
244-
context['project_credentials_login'] = task.project.task_data_login
245-
context['project_credentials_password'] = task.project.task_data_password
325+
ml_api_result = ml_backend.interactive_annotating(task, context, user=self.request.user)
326+
327+
if ml_api_result.is_error:
328+
message = f'Prediction not created for project {self}: {ml_api_result.error_message}'
329+
return self._error_response(message)
330+
331+
if not isinstance(ml_api_result.response, dict) or 'results' not in ml_api_result.response:
332+
message = f'Incorrect response from ML service it must be a dict and contain "results" key: {ml_api_result.response}'
333+
return self._error_response(message)
334+
335+
ml_results = self._get_ml_results(ml_api_result)
246336

247-
result = ml_backend.interactive_annotating(task, context, user=request.user)
337+
if not ml_results:
338+
message = f'ML backend has to return a list with at least 1 annotation but it returned: {type(ml_results)}'
339+
return self._error_response(message, logger.warning)
248340

249-
return Response(
250-
result,
251-
status=status.HTTP_200_OK,
252-
)
341+
return Response({'data': ml_results}, status=status.HTTP_200_OK)
253342

254343

255344
@method_decorator(

0 commit comments

Comments
 (0)