Skip to content

Commit ef614db

Browse files
committed
fix(LAB-4269): LLM projects need annotations as the jsonResponseUrl is not computed for those projects
1 parent 1753623 commit ef614db

4 files changed

Lines changed: 173 additions & 13 deletions

File tree

src/kili/adapters/kili_api_gateway/asset/operations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,9 @@ def get_assets_query(fragment: str) -> str:
3030
external_ids: filterExistingAssets(projectID: $projectID, externalIDs: $externalIDs)
3131
}
3232
"""
33+
34+
# GraphQL query calling the ``countAssetAnnotations`` resolver with an ``AssetWhere``
# filter; the resolver result is exposed under the ``data`` alias of the response.
GQL_COUNT_ASSET_ANNOTATIONS = """
35+
query countAssetAnnotations($where: AssetWhere!) {
36+
data: countAssetAnnotations(where: $where)
37+
}
38+
"""

src/kili/adapters/kili_api_gateway/asset/operations_mixin.py

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,16 @@
22

33
from collections.abc import Generator
44

5+
from kili_formats.tool.annotations_to_json_response import (
6+
AnnotationsToJsonResponseConverter,
7+
)
8+
59
from kili.adapters.kili_api_gateway.asset.formatters import (
610
load_asset_json_fields,
711
)
812
from kili.adapters.kili_api_gateway.asset.mappers import asset_where_mapper
913
from kili.adapters.kili_api_gateway.asset.operations import (
14+
GQL_COUNT_ASSET_ANNOTATIONS,
1015
GQL_COUNT_ASSETS,
1116
GQL_CREATE_UPLOAD_BUCKET_SIGNED_URLS,
1217
GQL_FILTER_EXISTING_ASSETS,
@@ -18,11 +23,18 @@
1823
QueryOptions,
1924
fragment_builder,
2025
)
26+
from kili.adapters.kili_api_gateway.label.common import get_annotation_fragment
2127
from kili.adapters.kili_api_gateway.project.common import get_project
2228
from kili.core.graphql.operations.asset.mutations import GQL_SET_ASSET_CONSENSUS
2329
from kili.domain.asset import AssetFilters
2430
from kili.domain.types import ListOrTuple
2531

32+
# Threshold on the number of annotations per batch of assets.
33+
# For LLM projects whose jsonResponse is requested, the counted number of annotations
34+
# divided by the asset batch size is compared to this value. When the ratio exceeds
35+
# the threshold, assets are fetched one at a time to avoid performance issues.
36+
THRESHOLD_FOR_BATCHING = 200
37+
2638

2739
class AssetOperationMixin(BaseOperationMixin):
2840
"""Mixin extending Kili API Gateway class with Assets related operations."""
@@ -66,30 +78,73 @@ def list_assets(
6678

6779
yield from assets_gen
6880

69-
def list_assets_split(
81+
def list_assets_split( # pylint: disable=too-many-branches
7082
self,
7183
filters: AssetFilters,
7284
fields: ListOrTuple[str],
7385
options: QueryOptions,
7486
project_info,
7587
) -> Generator[dict, None, None]:
7688
"""List assets with given options."""
89+
# For LLM projects, we need to fetch annotations and rebuild jsonResponse
90+
# because LLM projects don't have jsonResponseUrl
91+
is_llm_project = project_info["inputType"] in {
92+
"LLM_RLHF",
93+
"LLM_INSTR_FOLLOWING",
94+
"LLM_STATIC",
95+
}
96+
7797
assets_batch_max_amount = 10 if project_info["inputType"] == "VIDEO" else 50
7898
batch_size_to_use = min(options.batch_size, assets_batch_max_amount)
7999

80-
options = QueryOptions(options.disable_tqdm, options.first, options.skip, batch_size_to_use)
100+
# For LLM projects fetching annotations, adjust batch size based on annotation count
101+
if is_llm_project and (
102+
"labels.jsonResponse" in fields or "latestLabel.jsonResponse" in fields
103+
):
104+
nb_annotations = self.count_assets_annotations(filters)
105+
batch_size = (
106+
1
107+
if nb_annotations / batch_size_to_use > THRESHOLD_FOR_BATCHING
108+
else batch_size_to_use
109+
)
110+
else:
111+
batch_size = batch_size_to_use
112+
113+
options = QueryOptions(options.disable_tqdm, options.first, options.skip, batch_size)
114+
115+
requested_labels_json_response = "labels.jsonResponse" in fields
116+
requested_latest_label_json_response = "latestLabel.jsonResponse" in fields
117+
needs_json_response = requested_labels_json_response or requested_latest_label_json_response
81118

82119
required_fields = {"content", "jsonContent", "resolution.width", "resolution.height"}
83-
if "labels.jsonResponse" in fields:
84-
required_fields.add("labels.jsonResponseUrl")
85-
if "latestLabel.jsonResponse" in fields:
86-
required_fields.add("latestLabel.jsonResponseUrl")
87120
fields = list(fields)
121+
122+
static_fragments = {}
123+
if is_llm_project and needs_json_response:
124+
# For LLM projects: fetch annotations and rebuild jsonResponse client-side
125+
inner_annotation_fragment = get_annotation_fragment()
126+
annotation_fragment = f"""
127+
annotations {{
128+
{inner_annotation_fragment}
129+
}}
130+
"""
131+
static_fragments = {"labels": annotation_fragment, "latestLabel": annotation_fragment}
132+
133+
fields = list(fields)
134+
for field in required_fields:
135+
if field not in fields:
136+
fields.append(field)
137+
else:
138+
if requested_labels_json_response:
139+
required_fields.add("labels.jsonResponseUrl")
140+
if requested_latest_label_json_response:
141+
required_fields.add("latestLabel.jsonResponseUrl")
142+
88143
for field in required_fields:
89144
if field not in fields:
90145
fields.append(field)
91146

92-
fragment = fragment_builder(fields)
147+
fragment = fragment_builder(fields, static_fragments if static_fragments else None)
93148
query = get_assets_query(fragment)
94149
where = asset_where_mapper(filters)
95150
assets_gen = PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
@@ -99,7 +154,26 @@ def list_assets_split(
99154
load_asset_json_fields(asset, fields, self.http_client) for asset in assets_gen
100155
)
101156

102-
yield from assets_gen
157+
if is_llm_project and needs_json_response:
158+
# Rebuild jsonResponse from annotations for LLM projects
159+
converter = AnnotationsToJsonResponseConverter(
160+
json_interface=project_info["jsonInterface"],
161+
project_input_type=project_info["inputType"],
162+
)
163+
for asset in assets_gen:
164+
if requested_latest_label_json_response and asset.get("latestLabel"):
165+
converter.patch_label_json_response(
166+
asset, asset["latestLabel"], asset["latestLabel"]["annotations"]
167+
)
168+
asset["latestLabel"].pop("annotations", None)
169+
170+
if requested_labels_json_response:
171+
for label in asset.get("labels", []):
172+
converter.patch_label_json_response(asset, label, label["annotations"])
173+
label.pop("annotations", None)
174+
yield asset
175+
else:
176+
yield from assets_gen
103177

104178
def count_assets(self, filters: AssetFilters) -> int:
105179
"""Send a GraphQL request calling countIssues resolver."""
@@ -109,6 +183,14 @@ def count_assets(self, filters: AssetFilters) -> int:
109183
count: int = count_result["data"]
110184
return count
111185

186+
def count_assets_annotations(self, filters: AssetFilters) -> int:
    """Return the annotation count reported by the API for the given asset filters.

    The filters are converted with ``asset_where_mapper`` and sent as the ``where``
    variable of the ``GQL_COUNT_ASSET_ANNOTATIONS`` query; the ``data`` entry of the
    response is returned.
    """
    variables = {"where": asset_where_mapper(filters)}
    response = self.graphql_client.execute(GQL_COUNT_ASSET_ANNOTATIONS, variables)
    nb_annotations: int = response["data"]
    return nb_annotations
193+
112194
def create_upload_bucket_signed_urls(self, file_paths: list[str]) -> list[str]:
113195
"""Send a GraphQL request calling createUploadBucketSignedUrls resolver."""
114196
payload = {

src/kili/adapters/kili_api_gateway/label/common.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,42 @@
1010
from kili.exceptions import NotFound
1111

1212

13+
def get_annotation_fragment() -> str:
14+
"""Generate a basic annotation fragment for querying annotations.
15+
16+
This is used for LLM projects.

The returned GraphQL selection contains the fields ``__typename``, ``id``, ``job``,
``path`` and ``labelId``, plus ``annotationValue`` and ``chatItemId`` for
classification, comparison and transcription annotations. Callers wrap it in an
``annotations { ... }`` block.
17+
"""
18+
return """
19+
__typename
20+
id
21+
job
22+
path
23+
labelId
24+
... on ClassificationAnnotation {
25+
annotationValue {
26+
categories
27+
}
28+
chatItemId
29+
}
30+
... on ComparisonAnnotation {
31+
annotationValue {
32+
choice {
33+
code
34+
firstId
35+
secondId
36+
}
37+
}
38+
chatItemId
39+
}
40+
... on TranscriptionAnnotation {
41+
annotationValue {
42+
text
43+
}
44+
chatItemId
45+
}
46+
"""
47+
48+
1349
def get_asset(
1450
graphql_client: GraphQLClient,
1551
http_client: HttpClient,

src/kili/adapters/kili_api_gateway/label/operations_mixin.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections.abc import Generator
55
from typing import Optional
66

7+
from kili_formats.tool.annotations_to_json_response import AnnotationsToJsonResponseConverter
8+
79
from kili.adapters.kili_api_gateway.base import BaseOperationMixin
810
from kili.adapters.kili_api_gateway.helpers.queries import (
911
PaginatedGraphQLQuery,
@@ -19,6 +21,7 @@
1921
from kili.domain.types import ListOrTuple
2022
from kili.utils.tqdm import tqdm
2123

24+
from .common import get_annotation_fragment
2225
from .formatters import load_label_json_fields
2326
from .mappers import append_label_data_mapper, append_to_labels_data_mapper, label_where_mapper
2427
from .operations import (
@@ -87,20 +90,53 @@ def list_labels_split(
8790
options.disable_tqdm, options.first, options.skip, min(options.batch_size, 20)
8891
)
8992

93+
# For LLM projects, we need to fetch annotations and rebuild jsonResponse
94+
# because LLM projects don't have jsonResponseUrl
95+
is_llm_project = project_info["inputType"] in {
96+
"LLM_RLHF",
97+
"LLM_INSTR_FOLLOWING",
98+
"LLM_STATIC",
99+
}
100+
needs_json_response = "jsonResponse" in fields
101+
90102
fields = list(fields)
91-
if "jsonResponse" in fields and "jsonResponseUrl" not in fields:
92-
fields.append("jsonResponseUrl")
93103

94-
fragment = fragment_builder(fields)
95-
query = get_labels_query(fragment)
104+
if is_llm_project and needs_json_response:
105+
# For LLM projects: fetch annotations and rebuild jsonResponse client-side
106+
inner_annotation_fragment = get_annotation_fragment()
107+
full_fragment = f"""
108+
{fragment_builder([f for f in fields if f not in {"jsonResponse", "jsonResponseUrl"}])}
109+
annotations {{
110+
{inner_annotation_fragment}
111+
}}
112+
"""
113+
else:
114+
if "jsonResponse" in fields and "jsonResponseUrl" not in fields:
115+
fields.append("jsonResponseUrl")
116+
full_fragment = fragment_builder(fields)
117+
118+
query = get_labels_query(full_fragment)
96119
where = label_where_mapper(filters)
97120
labels_gen = PaginatedGraphQLQuery(self.graphql_client).execute_query_from_paginated_call(
98121
query, where, options, "Retrieving labels", GQL_COUNT_LABELS
99122
)
100123
labels_gen = (
101124
load_label_json_fields(label, fields, self.http_client) for label in labels_gen
102125
)
103-
yield from labels_gen
126+
127+
if is_llm_project and needs_json_response:
128+
# Rebuild jsonResponse from annotations for LLM projects
129+
converter = AnnotationsToJsonResponseConverter(
130+
json_interface=project_info["jsonInterface"],
131+
project_input_type=project_info["inputType"],
132+
)
133+
for label in labels_gen:
134+
asset = None
135+
converter.patch_label_json_response(asset, label, label["annotations"])
136+
label.pop("annotations", None)
137+
yield label
138+
else:
139+
yield from labels_gen
104140

105141
def delete_labels(
106142
self, ids: ListOrTuple[LabelId], disable_tqdm: Optional[bool]

0 commit comments

Comments
 (0)