|
1 | 1 | use sea_orm::*; |
2 | 2 | use serde_json; |
3 | 3 |
|
4 | | -use crate::entity::{conversation_summaries, conversations}; |
| 4 | +use crate::entity::{conversation_summaries, conversations, messages}; |
5 | 5 | use crate::error::{AQBotError, Result}; |
6 | 6 | use crate::types::{ |
7 | 7 | Conversation, ConversationSearchResult, ConversationSummary, UpdateConversationInput, |
@@ -30,6 +30,7 @@ fn conversation_from_entity(m: conversations::Model) -> Conversation { |
30 | 30 | is_archived: m.is_archived != 0, |
31 | 31 | context_compression: m.context_compression != 0, |
32 | 32 | category_id: m.category_id, |
| 33 | + parent_conversation_id: m.parent_conversation_id, |
33 | 34 | created_at: m.created_at, |
34 | 35 | updated_at: m.updated_at, |
35 | 36 | } |
@@ -170,6 +171,9 @@ pub async fn update_conversation( |
170 | 171 | if let Some(category_id) = input.category_id { |
171 | 172 | am.category_id = Set(category_id); |
172 | 173 | } |
| 174 | + if let Some(parent_conversation_id) = input.parent_conversation_id { |
| 175 | + am.parent_conversation_id = Set(parent_conversation_id); |
| 176 | + } |
173 | 177 | am.updated_at = Set(now); |
174 | 178 | am.update(db).await?; |
175 | 179 |
|
@@ -233,6 +237,146 @@ pub async fn delete_conversation(db: &DatabaseConnection, id: &str) -> Result<() |
233 | 237 | Ok(()) |
234 | 238 | } |
235 | 239 |
|
| 240 | +/// Branch a conversation: copy settings + messages up to `until_message_id`. |
| 241 | +/// If `as_child` is true, the new conversation is nested under the source (or its parent). |
| 242 | +pub async fn branch_conversation( |
| 243 | + db: &DatabaseConnection, |
| 244 | + conversation_id: &str, |
| 245 | + until_message_id: &str, |
| 246 | + as_child: bool, |
| 247 | + custom_title: Option<&str>, |
| 248 | +) -> Result<Conversation> { |
| 249 | + // 1. Load source conversation |
| 250 | + let source = conversations::Entity::find_by_id(conversation_id) |
| 251 | + .one(db) |
| 252 | + .await? |
| 253 | + .ok_or_else(|| AQBotError::NotFound(format!("Conversation {}", conversation_id)))?; |
| 254 | + |
| 255 | + // 2. Load all active messages ordered by created_at |
| 256 | + let all_msgs = messages::Entity::find() |
| 257 | + .filter(messages::Column::ConversationId.eq(conversation_id)) |
| 258 | + .filter(messages::Column::IsActive.eq(1)) |
| 259 | + .order_by_asc(messages::Column::CreatedAt) |
| 260 | + .all(db) |
| 261 | + .await?; |
| 262 | + |
| 263 | + // 3. Find the target message index |
| 264 | + let target_idx = all_msgs |
| 265 | + .iter() |
| 266 | + .position(|m| m.id == until_message_id) |
| 267 | + .ok_or_else(|| { |
| 268 | + AQBotError::NotFound(format!("Message {} in conversation", until_message_id)) |
| 269 | + })?; |
| 270 | + |
| 271 | + // 4. Slice messages up to (and including) the target |
| 272 | + let candidate_msgs = &all_msgs[..=target_idx]; |
| 273 | + |
| 274 | + // 5. Find last context-clear marker to determine effective start |
| 275 | + let start_idx = candidate_msgs |
| 276 | + .iter() |
| 277 | + .rposition(|m| { |
| 278 | + m.role == "system" |
| 279 | + && (m.content == "<!-- context-clear -->" |
| 280 | + || m.content == "<!-- context-compressed -->") |
| 281 | + }) |
| 282 | + .map(|idx| idx + 1) // skip the marker itself |
| 283 | + .unwrap_or(0); |
| 284 | + |
| 285 | + let effective_msgs = &candidate_msgs[start_idx..]; |
| 286 | + |
| 287 | + // 6. Create new conversation with copied settings |
| 288 | + let new_id = gen_id(); |
| 289 | + let now = now_ts(); |
| 290 | + let branch_title = custom_title |
| 291 | + .map(|t| t.to_string()) |
| 292 | + .unwrap_or_else(|| source.title.clone()); |
| 293 | + |
| 294 | + // Determine parent_conversation_id |
| 295 | + let parent_id = if as_child { |
| 296 | + // If source already has a parent, new branch is a sibling (same parent) |
| 297 | + // Otherwise, source becomes the parent |
| 298 | + Some( |
| 299 | + source |
| 300 | + .parent_conversation_id |
| 301 | + .clone() |
| 302 | + .unwrap_or_else(|| source.id.clone()), |
| 303 | + ) |
| 304 | + } else { |
| 305 | + None |
| 306 | + }; |
| 307 | + |
| 308 | + conversations::ActiveModel { |
| 309 | + id: Set(new_id.clone()), |
| 310 | + title: Set(branch_title), |
| 311 | + model_id: Set(source.model_id.clone()), |
| 312 | + provider_id: Set(source.provider_id.clone()), |
| 313 | + system_prompt: Set(source.system_prompt.clone()), |
| 314 | + temperature: Set(source.temperature), |
| 315 | + max_tokens: Set(source.max_tokens), |
| 316 | + top_p: Set(source.top_p), |
| 317 | + frequency_penalty: Set(source.frequency_penalty), |
| 318 | + search_enabled: Set(source.search_enabled), |
| 319 | + search_provider_id: Set(source.search_provider_id.clone()), |
| 320 | + thinking_budget: Set(source.thinking_budget), |
| 321 | + enabled_mcp_server_ids: Set(source.enabled_mcp_server_ids.clone()), |
| 322 | + enabled_knowledge_base_ids: Set(source.enabled_knowledge_base_ids.clone()), |
| 323 | + enabled_memory_namespace_ids: Set(source.enabled_memory_namespace_ids.clone()), |
| 324 | + message_count: Set(effective_msgs.len() as i32), |
| 325 | + is_pinned: Set(0), |
| 326 | + is_archived: Set(0), |
| 327 | + context_compression: Set(source.context_compression), |
| 328 | + category_id: Set(source.category_id.clone()), |
| 329 | + parent_conversation_id: Set(parent_id), |
| 330 | + research_mode: Set(source.research_mode), |
| 331 | + created_at: Set(now), |
| 332 | + updated_at: Set(now), |
| 333 | + ..Default::default() |
| 334 | + } |
| 335 | + .insert(db) |
| 336 | + .await?; |
| 337 | + |
| 338 | + // 7. Copy messages — assign new IDs and remap parent_message_id references |
| 339 | + let mut id_map = std::collections::HashMap::new(); |
| 340 | + for msg in effective_msgs { |
| 341 | + let new_msg_id = gen_id(); |
| 342 | + id_map.insert(msg.id.clone(), new_msg_id.clone()); |
| 343 | + |
| 344 | + let new_parent = msg |
| 345 | + .parent_message_id |
| 346 | + .as_ref() |
| 347 | + .and_then(|pid| id_map.get(pid)) |
| 348 | + .cloned(); |
| 349 | + |
| 350 | + messages::ActiveModel { |
| 351 | + id: Set(new_msg_id), |
| 352 | + conversation_id: Set(new_id.clone()), |
| 353 | + role: Set(msg.role.clone()), |
| 354 | + content: Set(msg.content.clone()), |
| 355 | + provider_id: Set(msg.provider_id.clone()), |
| 356 | + model_id: Set(msg.model_id.clone()), |
| 357 | + token_count: Set(msg.token_count), |
| 358 | + prompt_tokens: Set(msg.prompt_tokens), |
| 359 | + completion_tokens: Set(msg.completion_tokens), |
| 360 | + attachments: Set(msg.attachments.clone()), |
| 361 | + thinking: Set(msg.thinking.clone()), |
| 362 | + created_at: Set(msg.created_at), |
| 363 | + parent_message_id: Set(new_parent), |
| 364 | + version_index: Set(msg.version_index), |
| 365 | + is_active: Set(1), |
| 366 | + tool_calls_json: Set(msg.tool_calls_json.clone()), |
| 367 | + tool_call_id: Set(msg.tool_call_id.clone()), |
| 368 | + status: Set(msg.status.clone()), |
| 369 | + tokens_per_second: Set(msg.tokens_per_second), |
| 370 | + first_token_latency_ms: Set(msg.first_token_latency_ms), |
| 371 | + ..Default::default() |
| 372 | + } |
| 373 | + .insert(db) |
| 374 | + .await?; |
| 375 | + } |
| 376 | + |
| 377 | + get_conversation(db, &new_id).await |
| 378 | +} |
| 379 | + |
236 | 380 | pub async fn search_conversations( |
237 | 381 | db: &DatabaseConnection, |
238 | 382 | query: &str, |
|
0 commit comments