22
33from collections .abc import Generator
44
5+ from kili_formats .tool .annotations_to_json_response import (
6+ AnnotationsToJsonResponseConverter ,
7+ )
8+
59from kili .adapters .kili_api_gateway .asset .formatters import (
610 load_asset_json_fields ,
711)
812from kili .adapters .kili_api_gateway .asset .mappers import asset_where_mapper
913from kili .adapters .kili_api_gateway .asset .operations import (
14+ GQL_COUNT_ASSET_ANNOTATIONS ,
1015 GQL_COUNT_ASSETS ,
1116 GQL_CREATE_UPLOAD_BUCKET_SIGNED_URLS ,
1217 GQL_FILTER_EXISTING_ASSETS ,
1823 QueryOptions ,
1924 fragment_builder ,
2025)
26+ from kili .adapters .kili_api_gateway .label .common import get_annotation_fragment
2127from kili .adapters .kili_api_gateway .project .common import get_project
2228from kili .core .graphql .operations .asset .mutations import GQL_SET_ASSET_CONSENSUS
2329from kili .domain .asset import AssetFilters
2430from kili .domain .types import ListOrTuple
2531
32+ # Threshold used by list_assets_split when LLM-project assets are fetched with
33+ # their annotations: the total annotation count is divided by the asset batch size,
34+ # and if that ratio exceeds this threshold the assets are fetched one at a time
35+ # (batch size 1) instead of in regular batches, to avoid performance issues.
36+ THRESHOLD_FOR_BATCHING = 200
37+
2638
2739class AssetOperationMixin (BaseOperationMixin ):
2840 """Mixin extending Kili API Gateway class with Assets related operations."""
@@ -66,30 +78,73 @@ def list_assets(
6678
6779 yield from assets_gen
6880
69- def list_assets_split (
81+ def list_assets_split ( # pylint: disable=too-many-branches
7082 self ,
7183 filters : AssetFilters ,
7284 fields : ListOrTuple [str ],
7385 options : QueryOptions ,
7486 project_info ,
7587 ) -> Generator [dict , None , None ]:
7688 """List assets with given options."""
89+ # For LLM projects, we need to fetch annotations and rebuild jsonResponse
90+ # because LLM projects don't have jsonResponseUrl
91+ is_llm_project = project_info ["inputType" ] in {
92+ "LLM_RLHF" ,
93+ "LLM_INSTR_FOLLOWING" ,
94+ "LLM_STATIC" ,
95+ }
96+
7797 assets_batch_max_amount = 10 if project_info ["inputType" ] == "VIDEO" else 50
7898 batch_size_to_use = min (options .batch_size , assets_batch_max_amount )
7999
80- options = QueryOptions (options .disable_tqdm , options .first , options .skip , batch_size_to_use )
100+ # For LLM projects fetching annotations, adjust batch size based on annotation count
101+ if is_llm_project and (
102+ "labels.jsonResponse" in fields or "latestLabel.jsonResponse" in fields
103+ ):
104+ nb_annotations = self .count_assets_annotations (filters )
105+ batch_size = (
106+ 1
107+ if nb_annotations / batch_size_to_use > THRESHOLD_FOR_BATCHING
108+ else batch_size_to_use
109+ )
110+ else :
111+ batch_size = batch_size_to_use
112+
113+ options = QueryOptions (options .disable_tqdm , options .first , options .skip , batch_size )
114+
115+ requested_labels_json_response = "labels.jsonResponse" in fields
116+ requested_latest_label_json_response = "latestLabel.jsonResponse" in fields
117+ needs_json_response = requested_labels_json_response or requested_latest_label_json_response
81118
82119 required_fields = {"content" , "jsonContent" , "resolution.width" , "resolution.height" }
83- if "labels.jsonResponse" in fields :
84- required_fields .add ("labels.jsonResponseUrl" )
85- if "latestLabel.jsonResponse" in fields :
86- required_fields .add ("latestLabel.jsonResponseUrl" )
87120 fields = list (fields )
121+
122+ static_fragments = {}
123+ if is_llm_project and needs_json_response :
124+ # For LLM projects: fetch annotations and rebuild jsonResponse client-side
125+ inner_annotation_fragment = get_annotation_fragment ()
126+ annotation_fragment = f"""
127+ annotations {{
128+ { inner_annotation_fragment }
129+ }}
130+ """
131+ static_fragments = {"labels" : annotation_fragment , "latestLabel" : annotation_fragment }
132+
133+ fields = list (fields )
134+ for field in required_fields :
135+ if field not in fields :
136+ fields .append (field )
137+ else :
138+ if requested_labels_json_response :
139+ required_fields .add ("labels.jsonResponseUrl" )
140+ if requested_latest_label_json_response :
141+ required_fields .add ("latestLabel.jsonResponseUrl" )
142+
88143 for field in required_fields :
89144 if field not in fields :
90145 fields .append (field )
91146
92- fragment = fragment_builder (fields )
147+ fragment = fragment_builder (fields , static_fragments if static_fragments else None )
93148 query = get_assets_query (fragment )
94149 where = asset_where_mapper (filters )
95150 assets_gen = PaginatedGraphQLQuery (self .graphql_client ).execute_query_from_paginated_call (
@@ -99,7 +154,26 @@ def list_assets_split(
99154 load_asset_json_fields (asset , fields , self .http_client ) for asset in assets_gen
100155 )
101156
102- yield from assets_gen
157+ if is_llm_project and needs_json_response :
158+ # Rebuild jsonResponse from annotations for LLM projects
159+ converter = AnnotationsToJsonResponseConverter (
160+ json_interface = project_info ["jsonInterface" ],
161+ project_input_type = project_info ["inputType" ],
162+ )
163+ for asset in assets_gen :
164+ if requested_latest_label_json_response and asset .get ("latestLabel" ):
165+ converter .patch_label_json_response (
166+ asset , asset ["latestLabel" ], asset ["latestLabel" ]["annotations" ]
167+ )
168+ asset ["latestLabel" ].pop ("annotations" , None )
169+
170+ if requested_labels_json_response :
171+ for label in asset .get ("labels" , []):
172+ converter .patch_label_json_response (asset , label , label ["annotations" ])
173+ label .pop ("annotations" , None )
174+ yield asset
175+ else :
176+ yield from assets_gen
103177
104178 def count_assets (self , filters : AssetFilters ) -> int :
105179 """Send a GraphQL request calling countIssues resolver."""
@@ -109,6 +183,14 @@ def count_assets(self, filters: AssetFilters) -> int:
109183 count : int = count_result ["data" ]
110184 return count
111185
def count_assets_annotations(self, filters: AssetFilters) -> int:
    """Return the annotation count reported for the assets selected by ``filters``.

    The filters are converted with ``asset_where_mapper`` and sent as the
    ``where`` variable of the ``GQL_COUNT_ASSET_ANNOTATIONS`` operation; the
    ``data`` entry of the response is returned as the count.
    """
    variables = {"where": asset_where_mapper(filters)}
    response = self.graphql_client.execute(GQL_COUNT_ASSET_ANNOTATIONS, variables)
    nb_annotations: int = response["data"]
    return nb_annotations
193+
112194 def create_upload_bucket_signed_urls (self , file_paths : list [str ]) -> list [str ]:
113195 """Send a GraphQL request calling createUploadBucketSignedUrls resolver."""
114196 payload = {
0 commit comments