|
2 | 2 |
|
3 | 3 | from typing import Dict, List, Optional, Union |
4 | 4 |
|
| 5 | +from kili.adapters.kili_api_gateway.helpers.queries import QueryOptions |
5 | 6 | from kili.adapters.kili_api_gateway.kili_api_gateway import KiliAPIGateway |
6 | 7 | from kili.domain.asset.asset import AssetFilters |
7 | 8 | from kili.domain.project import ProjectId |
8 | 9 |
|
9 | 10 | from .dynamic import LLMDynamicExporter |
10 | 11 | from .static import LLMStaticExporter |
11 | 12 |
|
# Fields fetched for every chat item attached to a label.
CHAT_ITEMS_NEEDED_FIELDS = ["id", "content", "createdAt", "modelId", "parentId", "role"]

# Fields fetched for every label; the chat item fields are nested under "chatItems.".
LABELS_NEEDED_FIELDS = (
    [
        "annotations.id",
        "author.id",
        "author.email",
        "author.firstname",
        "author.lastname",
    ]
    + ["chatItems." + field for field in CHAT_ITEMS_NEEDED_FIELDS]
    + [
        "createdAt",
        "id",
        "isLatestLabelForUser",
        "isSentBackToQueue",
        "jsonResponse",  # This is needed to keep annotations
        "labelType",
        "modelName",
    ]
)

# Asset fields returned by get_fields_to_fetch for every input type other than "LLM_RLHF";
# the label fields are nested under "labels.".
ASSET_DYNAMIC_NEEDED_FIELDS = (
    [
        "assetProjectModels.id",
        "assetProjectModels.configuration",
        "assetProjectModels.name",
        "content",
        "externalId",
        "jsonMetadata",
    ]
    + ["labels." + field for field in LABELS_NEEDED_FIELDS]
    + ["status"]
)

# Asset fields returned by get_fields_to_fetch for the "LLM_RLHF" input type.
ASSET_STATIC_NEEDED_FIELDS = [
    "content",
    "externalId",
    "jsonMetadata",
    # Label-level fields.
    "labels.jsonResponse",
    "labels.author.id",
    "labels.author.email",
    "labels.author.firstname",
    "labels.author.lastname",
    "labels.createdAt",
    "labels.isLatestLabelForUser",
    "labels.isSentBackToQueue",
    "labels.labelType",
    "labels.modelName",
    "status",
]
| 65 | + |
12 | 66 |
|
def export(  # pylint: disable=too-many-arguments, too-many-locals
    kili_api_gateway: KiliAPIGateway,
    project_id: ProjectId,
    asset_filter: AssetFilters,
    disable_tqdm: Optional[bool],
    include_sent_back_labels: Optional[bool],
) -> Optional[List[Dict[str, Union[List[str], str]]]]:
    """Export the selected assets with their labels in the LLM export format.

    Args:
        kili_api_gateway: Gateway used to read the project and list its assets.
        project_id: Identifier of the project to export.
        asset_filter: Filter selecting the assets to export. Its ``status_in``
            attribute is overwritten in place (see below).
        disable_tqdm: Forwarded to the asset listing query options.
        include_sent_back_labels: When falsy (``None`` included), labels sent back
            to the queue are dropped before exporting.

    Returns:
        Whatever the exporter matching the project input type returns.

    Raises:
        ValueError: If the project input type is neither ``"LLM_RLHF"`` nor
            ``"LLM_INSTR_FOLLOWING"``.
    """
    project = kili_api_gateway.get_project(project_id, ["id", "inputType", "jsonInterface"])
    input_type = project["inputType"]

    # Reject unsupported projects before listing assets: the listing is the costly
    # part of the export and its result would be thrown away anyway.
    if input_type not in ("LLM_RLHF", "LLM_INSTR_FOLLOWING"):
        raise ValueError(f'Project Input type "{input_type}" cannot be used for llm exports.')

    fields = get_fields_to_fetch(input_type)
    # Replaces any status filter set by the caller and mutates the caller's filter object.
    asset_filter.status_in = ["LABELED", "REVIEWED", "TO_REVIEW"]
    assets = list(
        kili_api_gateway.list_assets(asset_filter, fields, QueryOptions(disable_tqdm=disable_tqdm))
    )
    cleaned_assets = preprocess_assets(assets, include_sent_back_labels or False)

    if input_type == "LLM_RLHF":
        return LLMStaticExporter(kili_api_gateway).export(
            cleaned_assets, project_id, project["jsonInterface"]
        )
    # Only "LLM_INSTR_FOLLOWING" can reach this point thanks to the check above.
    return LLMDynamicExporter(kili_api_gateway).export(cleaned_assets, project["jsonInterface"])
| 91 | + |
| 92 | + |
def get_fields_to_fetch(input_type: str) -> List[str]:
    """Pick the asset fields to query for the given project input type.

    ``"LLM_RLHF"`` maps to the static export fields; every other input type maps
    to the dynamic export fields. The module-level list itself is returned, not a copy.
    """
    is_static_export = input_type == "LLM_RLHF"
    return ASSET_STATIC_NEEDED_FIELDS if is_static_export else ASSET_DYNAMIC_NEEDED_FIELDS
| 98 | + |
| 99 | + |
def preprocess_assets(assets: List[Dict], include_sent_back_labels: bool) -> List[Dict]:
    """Keep the assets that still have a label to export, dropping sent-back labels on request.

    An asset is kept when at least one of these holds:

    * it has a ``"labels"`` list that is non-empty after filtering;
    * it has a non-``None`` ``"latestLabel"`` that passes the filter.

    A label passes the filter when ``include_sent_back_labels`` is true or its
    ``"isSentBackToQueue"`` value is exactly ``False``.

    Args:
        assets: Asset dictionaries. A kept asset has its ``"labels"`` entry replaced
            in place by the filtered list.
        include_sent_back_labels: Whether labels sent back to the queue are kept.

    Returns:
        The kept assets, in input order, each appearing at most once.
    """
    exportable_assets = []
    for asset in assets:
        keep_asset = False

        if "labels" in asset:
            labels = list(asset["labels"])
            if not include_sent_back_labels:
                labels = [label for label in labels if label["isSentBackToQueue"] is False]
            if labels:
                asset["labels"] = labels
                keep_asset = True

        latest_label = asset.get("latestLabel")
        if latest_label is not None and (
            include_sent_back_labels or latest_label["isSentBackToQueue"] is False
        ):
            keep_asset = True

        # A single append per asset: an asset carrying both "labels" and "latestLabel"
        # must not be exported twice.
        if keep_asset:
            exportable_assets.append(asset)
    return exportable_assets
0 commit comments