|
8 | 8 | CHAT_ITEMS_NEEDED_FIELDS = [ |
9 | 9 | "id", |
10 | 10 | "content", |
11 | | - "createdAt", |
12 | 11 | "modelId", |
13 | 12 | "parentId", |
14 | 13 | "role", |
@@ -123,42 +122,76 @@ def _init_round(self, context): |
123 | 122 |
|
124 | 123 | def _build_rounds(self, chat_items, annotations, json_interface): |
125 | 124 | """A round is composed of a prompt with n pre-prompts and n completions.""" |
126 | | - ordered_chat_items = sorted(chat_items, key=lambda x: x["createdAt"]) |
| 125 | + dict_chat_items = {} |
| 126 | + for chat_item in chat_items: |
| 127 | + if dict_chat_items.get(chat_item["parentId"]) is None: |
| 128 | + dict_chat_items[chat_item["parentId"]] = [] |
| 129 | + dict_chat_items[chat_item["parentId"]].append(chat_item) |
127 | 130 | rounds = [] |
| 131 | + parent_target = None |
| 132 | + has_children = True |
128 | 133 | current_round = self._init_round([]) |
129 | | - for chat_item in ordered_chat_items: |
130 | | - role = chat_item["role"].lower() if chat_item["role"] else None |
131 | | - if role == "user" or role == "system": |
132 | | - if current_round["prompt"] is not None: |
133 | | - rounds.append(current_round) |
134 | | - new_context = ( |
135 | | - current_round["context"] |
136 | | - + current_round["pre_prompts"] |
137 | | - + [ |
138 | | - current_round["prompt"], |
139 | | - self._get_round_winner( |
140 | | - current_round["completion"], |
141 | | - current_round["annotations"], |
142 | | - json_interface, |
143 | | - ), |
144 | | - ] |
145 | | - ) |
146 | | - current_round = self._init_round(new_context) |
147 | | - |
148 | | - if role == "user": |
149 | | - current_round["prompt"] = chat_item |
150 | | - elif role == "system": |
151 | | - current_round["pre_prompts"].append(chat_item) |
152 | | - elif role == "assistant": |
153 | | - current_round["completion"].append(chat_item) |
154 | | - else: |
155 | | - raise ValueError(f"Role {chat_item['role']} not supported") |
156 | | - current_round["annotations"] += [ |
157 | | - annotation |
158 | | - for annotation in annotations |
159 | | - if annotation["chatItemId"] == chat_item["id"] |
160 | | - ] |
161 | | - rounds.append(current_round) |
| 134 | + |
| 135 | + while has_children: |
| 136 | + nodes = dict_chat_items.get(parent_target) |
| 137 | + if nodes is None or len(nodes) == 0: |
| 138 | + has_children = False |
| 139 | + continue |
| 140 | + node = nodes[0] |
| 141 | + if node["role"].lower() == "system": |
| 142 | + current_round["pre_prompts"].append(node) |
| 143 | + parent_target = node["id"] |
| 144 | + current_round["annotations"] += [ |
| 145 | + annotation |
| 146 | + for annotation in annotations |
| 147 | + if annotation["chatItemId"] == node["id"] |
| 148 | + ] |
| 149 | + continue |
| 150 | + |
| 151 | + if node["role"].lower() == "user": |
| 152 | + current_round["prompt"] = node |
| 153 | + parent_target = node["id"] |
| 154 | + current_round["annotations"] += [ |
| 155 | + annotation |
| 156 | + for annotation in annotations |
| 157 | + if annotation["chatItemId"] == node["id"] |
| 158 | + ] |
| 159 | + continue |
| 160 | + |
| 161 | + if node["role"].lower() == "assistant": |
| 162 | + has_children = False |
| 163 | + if dict_chat_items.get(parent_target) is None: |
| 164 | + continue |
| 165 | + for chat_item in dict_chat_items[parent_target]: |
| 166 | + current_round["completion"].append(chat_item) |
| 167 | + current_round["annotations"] += [ |
| 168 | + annotation |
| 169 | + for annotation in annotations |
| 170 | + if annotation["chatItemId"] == chat_item["id"] |
| 171 | + ] |
| 172 | + if not has_children and dict_chat_items.get(chat_item["id"]) is not None: |
| 173 | + has_children = True |
| 174 | + parent_target = chat_item["id"] |
| 175 | + |
| 176 | + rounds.append(current_round) |
| 177 | + new_context = ( |
| 178 | + current_round["context"] |
| 179 | + + current_round["pre_prompts"] |
| 180 | + + [ |
| 181 | + current_round["prompt"], |
| 182 | + self._get_round_winner( |
| 183 | + current_round["completion"], |
| 184 | + current_round["annotations"], |
| 185 | + json_interface, |
| 186 | + ), |
| 187 | + ] |
| 188 | + ) |
| 189 | + current_round = self._init_round(new_context) |
| 190 | + continue |
| 191 | + |
| 192 | + raise ValueError(f"Role {node['role']} not supported") |
| 193 | + if current_round["prompt"] is not None: |
| 194 | + rounds.append(current_round) |
162 | 195 | return rounds |
163 | 196 |
|
164 | 197 |
|
|
0 commit comments