|
32 | 32 | import atexit |
33 | 33 | import json |
34 | 34 | import logging |
| 35 | +import re |
35 | 36 | import threading |
36 | | -from collections.abc import Coroutine, Iterable |
| 37 | +from collections.abc import Callable, Coroutine, Iterable |
37 | 38 | from typing import Any |
38 | 39 |
|
39 | 40 | from mcp import ClientSession, StdioServerParameters |
|
42 | 43 | from mcp.client.streamable_http import streamablehttp_client |
43 | 44 |
|
44 | 45 | from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT |
| 46 | +from data_designer.config.utils.image_helpers import ( |
| 47 | + decode_base64_image, |
| 48 | + detect_image_format, |
| 49 | + extract_base64_from_data_uri, |
| 50 | +) |
45 | 51 | from data_designer.engine.mcp.errors import MCPToolError |
46 | 52 | from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult |
47 | 53 |
|
48 | 54 | logger = logging.getLogger(__name__) |
| 55 | +_DATA_URI_MIME_TYPE_RE = re.compile(r"^data:(?P<mime_type>[^;]+);base64,") |
49 | 56 |
|
50 | 57 |
|
51 | 58 | def _provider_cache_key(provider: MCPProviderT) -> str: |
@@ -289,7 +296,7 @@ async def _call_tool_async( |
289 | 296 | session = await self._get_or_create_session(provider) |
290 | 297 | result = await session.call_tool(tool_name, arguments) |
291 | 298 |
|
292 | | - content = _serialize_tool_result_content(result) |
| 299 | + content = _coerce_tool_result_content(result) |
293 | 300 | is_error = getattr(result, "isError", None) |
294 | 301 | if is_error is None: |
295 | 302 | is_error = getattr(result, "is_error", False) |
@@ -467,31 +474,195 @@ def _coerce_tool_definition(tool: Any, tool_definition_cls: type[MCPToolDefiniti |
467 | 474 | return tool_definition_cls(name=name, description=description, input_schema=input_schema) |
468 | 475 |
|
469 | 476 |
|
def _coerce_tool_result_content(result: Any) -> str | list[dict[str, Any]]:
    """Turn an MCP tool result into text, or into content blocks when images are present.

    The payload is read from ``result.content`` when that attribute exists,
    otherwise ``result`` itself is treated as the payload.

    Returns:
        * ``""`` for a ``None`` payload, and the payload itself for a string.
        * A one-element list holding an ``image_url`` block when the payload is
          a single image (dict or object).
        * The text of a single text payload, or the JSON dump of any other dict.
        * For a list payload: the list of coerced blocks if at least one of them
          is an ``image_url`` block, otherwise the blocks' text joined by
          newlines.
        * ``str(payload)`` for anything else.
    """
    payload = getattr(result, "content", result)
    if payload is None:
        return ""
    if isinstance(payload, str):
        return payload

    if isinstance(payload, dict):
        # Already-shaped image_url blocks are validated rather than rebuilt.
        if _is_image_url_block(payload):
            return [_coerce_image_url_block(payload)]
        if _is_image_content(payload) or _has_base64_image_payload(payload):
            return [_build_image_url_block(payload)]
        if _is_text_content(payload):
            return str(payload.get("text", ""))
        return json.dumps(payload)

    # Non-dict payloads (e.g. attribute-style content objects) are probed the
    # same way before falling back to list handling.
    if _is_image_content(payload) or _has_base64_image_payload(payload):
        return [_build_image_url_block(payload)]
    if _is_text_content(payload):
        return str(_get_content_field(payload, "text", default=""))
    if not isinstance(payload, list):
        return str(payload)

    coerced = [_coerce_tool_result_content_item(entry) for entry in payload]
    if any(entry.get("type") == "image_url" for entry in coerced):
        # Keep the structured form so image blocks are not flattened to text.
        return coerced
    return "\n".join(entry.get("text", "") for entry in coerced)
| 507 | + |
| 508 | + |
def _coerce_tool_result_content_item(item: Any) -> dict[str, Any]:
    """Convert one MCP content item into a ``text`` or ``image_url`` block.

    Strings become text blocks. Items that already look like ``image_url``
    blocks are validated, and image-like items are rebuilt as ``image_url``
    blocks. Everything else is reduced to a text block: the item's ``text``
    field for text content, the JSON dump for other dicts, the ``text``
    attribute when it is not ``None``, and finally the item itself.
    """
    if isinstance(item, str):
        return _build_text_block(item)
    if _is_image_url_block(item):
        return _coerce_image_url_block(item)
    if _is_image_content(item) or _has_base64_image_payload(item):
        return _build_image_url_block(item)

    # From here on the item is textual; pick the source and wrap it once.
    if _is_text_content(item):
        text_source = _get_content_field(item, "text", default="")
    elif isinstance(item, dict):
        text_source = json.dumps(item)
    else:
        attr_text = getattr(item, "text", None)
        text_source = item if attr_text is None else attr_text
    return _build_text_block(text_source)
| 526 | + |
| 527 | + |
def _is_text_content(item: Any) -> bool:
    """Report whether the item's ``type`` field equals ``"text"``."""
    content_type = _get_content_field(item, "type")
    return content_type == "text"
| 530 | + |
| 531 | + |
def _is_image_content(item: Any) -> bool:
    """Report whether the item's ``type`` field equals ``"image"``."""
    content_type = _get_content_field(item, "type")
    return content_type == "image"
| 534 | + |
| 535 | + |
def _is_image_url_block(item: Any) -> bool:
    """Report whether the item is a dict whose ``type`` is ``"image_url"``."""
    if not isinstance(item, dict):
        return False
    return item.get("type") == "image_url"
| 538 | + |
| 539 | + |
def _has_base64_image_payload(item: Any) -> bool:
    """Report whether the item carries a string payload identified as an image.

    The payload is looked up under ``data``, ``b64_json`` or ``base64``. When
    the item declares a non-empty MIME type (``mimeType``, ``mime_type`` or
    ``media_type``), that declaration alone decides the answer. Otherwise the
    payload must be a ``data:<mime>;base64,`` URI with an image MIME type.
    """
    payload = _get_content_field(item, "data", "b64_json", "base64")
    if not (isinstance(payload, str) and payload):
        return False

    declared = _get_content_field(item, "mimeType", "mime_type", "media_type")
    if isinstance(declared, str) and declared:
        return _is_image_mime_type(declared)

    embedded = _extract_data_uri_mime_type(payload)
    if embedded is None:
        return False
    return _is_image_mime_type(embedded)
| 551 | + |
| 552 | + |
def _coerce_image_url_block(block: dict[str, Any]) -> dict[str, Any]:
    """Validate and normalize an ``image_url`` block returned by an MCP tool.

    Args:
        block: A dict whose ``image_url`` entry is either a URL string or a
            dict containing a ``url`` string.

    Returns:
        A ``{"type": "image_url", "image_url": {...}}`` block. HTTP(S) URLs and
        ``data:`` URIs are passed through (the latter after validation); any
        other string is treated as a raw base64 payload and rebuilt into a
        data URI via ``_build_image_url_block``.

    Raises:
        MCPToolError: If ``image_url`` is neither a dict nor a string, if the
            URL is missing/empty/not a string, or if the data URI or base64
            payload fails validation in the helpers called below.
    """
    image_url = block.get("image_url")
    if isinstance(image_url, str):
        image_url = {"url": image_url}
    elif not isinstance(image_url, dict):
        raise MCPToolError("MCP image_url block must contain an image_url dict or string.")

    url = image_url.get("url")
    if not isinstance(url, str) or not url:
        raise MCPToolError("MCP image_url block must contain a non-empty string URL.")

    # URI schemes are case-insensitive (RFC 3986 section 3.1), so "HTTPS://..."
    # must be recognized as a remote URL instead of falling through to the
    # raw-base64 path. Only the prefix is lower-cased to avoid copying what may
    # be a very large payload string.
    if url[:8].lower().startswith(("http://", "https://")):
        return {"type": "image_url", "image_url": image_url}
    if url.startswith("data:"):
        # Both calls are made for their validation side effect (they raise
        # MCPToolError on a non-image MIME type or undecodable payload).
        _extract_mime_type_from_data_uri(url)
        _coerce_base64_image_data(url)
        return {"type": "image_url", "image_url": image_url}

    # Anything else is assumed to be bare base64 image data.
    return _build_image_url_block({"base64": url})
| 571 | + |
| 572 | + |
def _build_image_url_block(item: Any) -> dict[str, Any]:
    """Build an ``image_url`` block with a ``data:`` URI from an image-like item.

    The payload is read from ``data``, ``b64_json`` or ``base64`` and the
    optional MIME type from ``mimeType``, ``mime_type`` or ``media_type``.

    Raises:
        MCPToolError: If the payload is missing, empty or not a string, or if
            MIME-type resolution or base64 validation fails.
    """
    raw = _get_content_field(item, "data", "b64_json", "base64")
    declared_mime = _get_content_field(item, "mimeType", "mime_type", "media_type")
    if not (isinstance(raw, str) and raw):
        raise MCPToolError("MCP image content is missing base64 data.")

    # MIME type is resolved before the payload is validated, so a bad MIME
    # type is reported first.
    resolved_mime = _coerce_image_mime_type(raw, declared_mime)
    encoded = _coerce_base64_image_data(raw)

    data_uri = f"data:{resolved_mime};base64,{encoded}"
    return {"type": "image_url", "image_url": {"url": data_uri}}
| 585 | + |
| 586 | + |
def _coerce_image_mime_type(data: str, mime_type: Any) -> str:
    """Resolve the image MIME type for a payload.

    Resolution order: the declared ``mime_type`` if it is a non-empty string,
    then the MIME type embedded in a ``data:`` URI, then a type derived from
    the decoded payload via ``detect_image_format``.

    Raises:
        MCPToolError: If the declared or embedded MIME type is not an image
            type, or if decoding/detection raises ``ValueError``.
    """
    if isinstance(mime_type, str) and mime_type:
        if _is_image_mime_type(mime_type):
            return mime_type
        raise MCPToolError(f"MCP image content must use an image MIME type, got {mime_type!r}.")

    embedded = _extract_mime_type_from_data_uri(data)
    if embedded is not None:
        return embedded

    # Last resort: inspect the decoded payload itself.
    try:
        image_bytes = decode_base64_image(data)
        detected = detect_image_format(image_bytes)
        return f"image/{detected.value}"
    except ValueError as exc:
        raise MCPToolError("MCP image content is missing a MIME type.") from exc
| 601 | + |
| 602 | + |
def _coerce_base64_image_data(data: str) -> str:
    """Return the base64 payload of ``data`` after checking that it decodes.

    ``extract_base64_from_data_uri`` produces the payload and
    ``decode_base64_image`` is called only to validate it; the decoded result
    is discarded.

    Raises:
        MCPToolError: If either helper raises ``ValueError``.
    """
    try:
        payload = extract_base64_from_data_uri(data)
        decode_base64_image(payload)
    except ValueError as exc:
        raise MCPToolError("MCP image content has invalid base64 data.") from exc
    return payload
| 610 | + |
| 611 | + |
def _extract_mime_type_from_data_uri(data: str) -> str | None:
    """Return the image MIME type embedded in a base64 data URI, if any.

    Returns ``None`` when ``data`` does not start with a
    ``data:<mime>;base64,`` prefix.

    Raises:
        MCPToolError: If a MIME type is present but is not an image type.
    """
    mime_type = _extract_data_uri_mime_type(data)
    if mime_type is not None and not _is_image_mime_type(mime_type):
        raise MCPToolError(f"MCP image content data URI must use an image MIME type, got {mime_type!r}.")
    return mime_type
| 619 | + |
| 620 | + |
def _extract_data_uri_mime_type(data: str) -> str | None:
    """Return the MIME type from a ``data:<mime>;base64,`` prefix, or ``None``.

    No check is made here that the MIME type is an image type.
    """
    prefix_match = _DATA_URI_MIME_TYPE_RE.match(data)
    return prefix_match.group("mime_type") if prefix_match is not None else None
| 626 | + |
| 627 | + |
def _is_image_mime_type(mime_type: str) -> bool:
    """Report whether ``mime_type`` starts with ``image/``, ignoring case."""
    normalized = mime_type.lower()
    return normalized.startswith("image/")
| 630 | + |
| 631 | + |
def _get_content_field(item: Any, *names: str, default: Any = None) -> Any:
    """Look up the first available field among ``names`` on a dict or object.

    For dicts, the value of the first name present as a key is returned (even
    if that value is ``None``). For other objects, the first name present as
    an attribute wins; failing that, the object's ``model_dump`` or ``dict``
    method (whichever is found first and is callable) is used to obtain a dict
    to search. ``default`` is returned when nothing matches.
    """
    if isinstance(item, dict):
        missing = object()
        key = next((candidate for candidate in names if candidate in item), missing)
        return default if key is missing else item[key]

    for attr in names:
        if hasattr(item, attr):
            return getattr(item, attr)

    # Fall back to a dump method so alias-only field names can still be found.
    for dumper_name in ("model_dump", "dict"):
        dumper = getattr(item, dumper_name, None)
        if callable(dumper):
            return _get_content_field_from_dump(dumper, names, default)

    return default
| 652 | + |
| 653 | + |
def _get_content_field_from_dump(dump_method: Callable[..., Any], names: tuple[str, ...], default: Any) -> Any:
    """Search the dict produced by a dump method for the first matching name.

    The method is called first with ``by_alias=True`` and then with no
    arguments. A call that raises ``TypeError`` or returns a non-dict is
    skipped. ``default`` is returned when no attempt yields one of ``names``.
    """
    attempts = ({"by_alias": True}, {})
    for options in attempts:
        try:
            snapshot = dump_method(**options)
        except TypeError:
            # The dump method rejected these keyword arguments; try the next form.
            continue
        if not isinstance(snapshot, dict):
            continue
        for key in names:
            if key in snapshot:
                return snapshot[key]
    return default
| 665 | + |
| 666 | + |
def _build_text_block(value: Any) -> dict[str, Any]:
    """Wrap ``str(value)`` in a ``{"type": "text", "text": ...}`` block."""
    text = str(value)
    return {"type": "text", "text": text}
0 commit comments