|
1 | 1 | """Skill management HTTP endpoints.""" |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | import logging |
4 | 5 | import os |
5 | | -import re |
| 6 | +import threading |
6 | 7 | from typing import Any, Dict, List, Optional |
7 | 8 |
|
8 | 9 | from fastapi import APIRouter, HTTPException, Query, UploadFile, File, Form, Header |
9 | | -from starlette.responses import JSONResponse |
| 10 | +from starlette.responses import JSONResponse, StreamingResponse |
10 | 11 | from pydantic import BaseModel |
11 | 12 |
|
12 | 13 | from consts.exceptions import SkillException, UnauthorizedError |
13 | 14 | from services.skill_service import SkillService |
14 | 15 | from consts.model import SkillInstanceInfoRequest |
15 | | -from utils.auth_utils import get_current_user_id |
| 16 | +from utils.auth_utils import get_current_user_id, get_current_user_info |
| 17 | +from utils.prompt_template_utils import get_skill_creation_simple_prompt_template |
| 18 | +from nexent.core.agents.agent_model import ModelConfig |
| 19 | +from agents.skill_creation_agent import create_simple_skill_from_request |
| 20 | +from nexent.core.utils.observer import MessageObserver |
16 | 21 |
|
17 | 22 | logger = logging.getLogger(__name__) |
18 | 23 |
|
19 | 24 | router = APIRouter(prefix="/skills", tags=["skills"]) |
| 25 | +skill_creator_router = APIRouter(prefix="/skills", tags=["simple-skills"]) |
20 | 26 |
|
21 | 27 |
|
22 | 28 | class SkillCreateRequest(BaseModel): |
@@ -453,88 +459,143 @@ async def delete_skill( |
453 | 459 | raise HTTPException(status_code=500, detail="Internal server error") |
454 | 460 |
|
455 | 461 |
|
456 | | -@router.delete("/{skill_name}/files/{file_path:path}") |
457 | | -async def delete_skill_file( |
458 | | - skill_name: str, |
459 | | - file_path: str, |
460 | | - authorization: Optional[str] = Header(None) |
461 | | -) -> JSONResponse: |
462 | | - """Delete a specific file within a skill directory. |
| 462 | +class SkillCreateSimpleRequest(BaseModel): |
| 463 | + """Request model for interactive skill creation.""" |
| 464 | + user_request: str |
463 | 465 |
|
464 | | - Args: |
465 | | - skill_name: Name of the skill |
466 | | - file_path: Relative path to the file within the skill directory |
467 | | - """ |
468 | | - try: |
469 | | - _, _ = get_current_user_id(authorization) |
470 | | - service = SkillService() |
471 | 466 |
|
472 | | - # Validate skill_name so it cannot be used for path traversal |
473 | | - if not skill_name: |
474 | | - raise HTTPException(status_code=400, detail="Invalid skill name") |
475 | | - if os.sep in skill_name or "/" in skill_name or ".." in skill_name: |
476 | | - raise HTTPException(status_code=400, detail="Invalid skill name") |
477 | | - |
478 | | - # Read config to get temp_filename for validation |
479 | | - config_content = service.get_skill_file_content(skill_name, "config.yaml") |
480 | | - if config_content is None: |
481 | | - raise HTTPException(status_code=404, detail="Config file not found") |
482 | | - |
483 | | - # Parse config to get temp_filename |
484 | | - import yaml |
485 | | - config = yaml.safe_load(config_content) |
486 | | - temp_filename = config.get("temp_filename", "") |
487 | | - |
488 | | - # Get the base directory for the skill |
489 | | - local_dir = os.path.join(service.skill_manager.local_skills_dir, skill_name) |
490 | | - |
491 | | - # Check for path traversal patterns in the raw file_path BEFORE any normalization |
492 | | - # This catches attempts like ../../etc/passwd or /etc/passwd |
493 | | - normalized_for_check = os.path.normpath(file_path) |
494 | | - if ".." in file_path or file_path.startswith("/") or (os.sep in file_path and file_path.startswith(os.sep)): |
495 | | - # Additional check: ensure the normalized path doesn't escape local_dir |
496 | | - abs_local_dir = os.path.abspath(local_dir) |
497 | | - abs_full_path = os.path.abspath(os.path.join(local_dir, normalized_for_check)) |
498 | | - try: |
499 | | - common = os.path.commonpath([abs_local_dir, abs_full_path]) |
500 | | - if common != abs_local_dir: |
501 | | - raise HTTPException(status_code=400, detail="Invalid file path: path traversal detected") |
502 | | - except ValueError: |
503 | | - raise HTTPException(status_code=400, detail="Invalid file path: path traversal detected") |
504 | | - |
505 | | - # Normalize the requested file path - use basename to strip directory components |
506 | | - safe_file_path = os.path.basename(os.path.normpath(file_path)) |
507 | | - |
508 | | - # Build full path and validate it stays within local_dir |
509 | | - full_path = os.path.normpath(os.path.join(local_dir, safe_file_path)) |
510 | | - abs_local_dir = os.path.abspath(local_dir) |
511 | | - abs_full_path = os.path.abspath(full_path) |
512 | | - |
513 | | - # Check for path traversal: abs_full_path should be within abs_local_dir |
514 | | - try: |
515 | | - common = os.path.commonpath([abs_local_dir, abs_full_path]) |
516 | | - if common != abs_local_dir: |
517 | | - raise HTTPException(status_code=400, detail="Invalid file path: path traversal detected") |
518 | | - except ValueError: |
519 | | - # Different drives on Windows |
520 | | - raise HTTPException(status_code=400, detail="Invalid file path: path traversal detected") |
| 467 | +def _build_model_config_from_tenant(tenant_id: str) -> ModelConfig: |
| 468 | + """Build ModelConfig from tenant's quick-config LLM model.""" |
| 469 | + from utils.config_utils import tenant_config_manager, get_model_name_from_config |
| 470 | + from consts.const import MODEL_CONFIG_MAPPING |
521 | 471 |
|
522 | | - # Validate the filename matches temp_filename |
523 | | - if not temp_filename or safe_file_path != temp_filename: |
524 | | - raise HTTPException(status_code=400, detail="Can only delete temp_filename files") |
| 472 | + quick_config = tenant_config_manager.get_model_config( |
| 473 | + key=MODEL_CONFIG_MAPPING["llm"], |
| 474 | + tenant_id=tenant_id |
| 475 | + ) |
| 476 | + if not quick_config: |
| 477 | + raise ValueError("No LLM model configured for tenant") |
525 | 478 |
|
526 | | - # Check if file exists |
527 | | - if not os.path.exists(full_path): |
528 | | - raise HTTPException(status_code=404, detail=f"File not found: {safe_file_path}") |
| 479 | + return ModelConfig( |
| 480 | + cite_name=quick_config.get("display_name", "default"), |
| 481 | + api_key=quick_config.get("api_key", ""), |
| 482 | + model_name=get_model_name_from_config(quick_config), |
| 483 | + url=quick_config.get("base_url", ""), |
| 484 | + temperature=0.1, |
| 485 | + top_p=0.95, |
| 486 | + ssl_verify=True, |
| 487 | + model_factory=quick_config.get("model_factory") |
| 488 | + ) |
529 | 489 |
|
530 | | - os.remove(full_path) |
531 | | - logger.info(f"Deleted skill file: {full_path}") |
532 | 490 |
|
533 | | - return JSONResponse(content={"message": f"File {safe_file_path} deleted successfully"}) |
534 | | - except UnauthorizedError as e: |
535 | | - raise HTTPException(status_code=401, detail=str(e)) |
536 | | - except HTTPException: |
537 | | - raise |
538 | | - except Exception as e: |
539 | | - logger.error(f"Error deleting skill file {skill_name}/{file_path}: {e}") |
540 | | - raise HTTPException(status_code=500, detail=str(e)) |
| 491 | +@skill_creator_router.post("/create-simple") |
| 492 | +async def create_simple_skill( |
| 493 | + request: SkillCreateSimpleRequest, |
| 494 | + authorization: Optional[str] = Header(None) |
| 495 | +): |
| 496 | + """Create a simple skill interactively via LLM agent. |
| 497 | +
|
| 498 | + Loads the skill_creation_simple prompt template, runs an internal agent |
| 499 | + with WriteSkillFileTool and ReadSkillMdTool, extracts the <SKILL> block |
| 500 | + from the final answer, and streams step progress and token content via SSE. |
| 501 | +
|
| 502 | + Yields SSE events: |
| 503 | + - step_count: Current agent step number |
| 504 | + - skill_content: Token-level content (thinking, code, deep_thinking, tool output) |
| 505 | + - final_answer: Complete skill content |
| 506 | + - done: Stream completion signal |
| 507 | + """ |
| 508 | + # Message types to stream as skill_content (token-level output) |
| 509 | + STREAMABLE_CONTENT_TYPES = frozenset([ |
| 510 | + "model_output_thinking", |
| 511 | + "model_output_code", |
| 512 | + "model_output_deep_thinking", |
| 513 | + "tool", |
| 514 | + "execution_logs", |
| 515 | + ]) |
| 516 | + |
| 517 | + async def generate(): |
| 518 | + import json |
| 519 | + try: |
| 520 | + _, tenant_id, language = get_current_user_info(authorization) |
| 521 | + |
| 522 | + template = get_skill_creation_simple_prompt_template(language) |
| 523 | + |
| 524 | + model_config = _build_model_config_from_tenant(tenant_id) |
| 525 | + observer = MessageObserver(lang=language) |
| 526 | + stop_event = threading.Event() |
| 527 | + |
| 528 | + # Get local_skills_dir from SkillManager |
| 529 | + skill_service = SkillService() |
| 530 | + local_skills_dir = skill_service.skill_manager.local_skills_dir or "" |
| 531 | + |
| 532 | + # Start skill creation in background thread |
| 533 | + def run_task(): |
| 534 | + create_simple_skill_from_request( |
| 535 | + system_prompt=template.get("system_prompt", ""), |
| 536 | + user_prompt=request.user_request, |
| 537 | + model_config_list=[model_config], |
| 538 | + observer=observer, |
| 539 | + stop_event=stop_event, |
| 540 | + local_skills_dir=local_skills_dir |
| 541 | + ) |
| 542 | + |
| 543 | + thread = threading.Thread(target=run_task) |
| 544 | + thread.start() |
| 545 | + |
| 546 | + # Poll observer for step_count and token content messages |
| 547 | + while thread.is_alive(): |
| 548 | + cached = observer.get_cached_message() |
| 549 | + for msg in cached: |
| 550 | + if isinstance(msg, str): |
| 551 | + try: |
| 552 | + data = json.loads(msg) |
| 553 | + msg_type = data.get("type", "") |
| 554 | + content = data.get("content", "") |
| 555 | + |
| 556 | + # Stream step progress |
| 557 | + if msg_type == "step_count": |
| 558 | + yield f"data: {json.dumps({'type': 'step_count', 'content': content}, ensure_ascii=False)}\n\n" |
| 559 | + # Stream token content (thinking, code, deep_thinking, tool output) |
| 560 | + elif msg_type in STREAMABLE_CONTENT_TYPES: |
| 561 | + yield f"data: {json.dumps({'type': 'skill_content', 'content': content}, ensure_ascii=False)}\n\n" |
| 562 | + # Stream final_answer content separately |
| 563 | + elif msg_type == "final_answer": |
| 564 | + yield f"data: {json.dumps({'type': 'final_answer', 'content': content}, ensure_ascii=False)}\n\n" |
| 565 | + except (json.JSONDecodeError, Exception): |
| 566 | + pass |
| 567 | + await asyncio.sleep(0.1) |
| 568 | + |
| 569 | + thread.join() |
| 570 | + |
| 571 | + # Stream any remaining cached messages after thread completes |
| 572 | + remaining = observer.get_cached_message() |
| 573 | + for msg in remaining: |
| 574 | + if isinstance(msg, str): |
| 575 | + try: |
| 576 | + data = json.loads(msg) |
| 577 | + msg_type = data.get("type", "") |
| 578 | + content = data.get("content", "") |
| 579 | + |
| 580 | + if msg_type == "step_count": |
| 581 | + yield f"data: {json.dumps({'type': 'step_count', 'content': content}, ensure_ascii=False)}\n\n" |
| 582 | + elif msg_type in STREAMABLE_CONTENT_TYPES: |
| 583 | + yield f"data: {json.dumps({'type': 'skill_content', 'content': content}, ensure_ascii=False)}\n\n" |
| 584 | + elif msg_type == "final_answer": |
| 585 | + yield f"data: {json.dumps({'type': 'final_answer', 'content': content}, ensure_ascii=False)}\n\n" |
| 586 | + except (json.JSONDecodeError, Exception): |
| 587 | + pass |
| 588 | + |
| 589 | + # Stream final answer content from observer |
| 590 | + final_result = observer.get_final_answer() |
| 591 | + if final_result: |
| 592 | + yield f"data: {json.dumps({'type': 'final_answer', 'content': final_result}, ensure_ascii=False)}\n\n" |
| 593 | + |
| 594 | + # Send done signal |
| 595 | + yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" |
| 596 | + |
| 597 | + except Exception as e: |
| 598 | + logger.error(f"Error in create_simple_skill stream: {e}") |
| 599 | + yield f"data: {json.dumps({'type': 'error', 'message': str(e)}, ensure_ascii=False)}\n\n" |
| 600 | + |
| 601 | + return StreamingResponse(generate(), media_type="text/event-stream") |
0 commit comments