|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
14 | 14 | from typing import Any, Dict, Iterable, Optional, Sequence |
| 15 | +from uuid import uuid4 |
15 | 16 |
|
16 | 17 | import numpy as np |
17 | 18 | import sidpy as sid |
|
39 | 40 | ] |
40 | 41 |
|
41 | 42 | SHO_PARAMETER_LABELS = ["amplitude", "resonance_frequency", "quality_factor", "phase"] |
| 43 | +DATASET_REGISTRY: Dict[str, sid.Dataset] = {} |
42 | 44 |
|
43 | 45 |
|
44 | 46 | def loop_fit_function(vdc: Sequence[float], *coef_vec: float) -> np.ndarray: |
@@ -216,6 +218,66 @@ def _as_builtin(value: Any) -> Any: |
216 | 218 | return value |
217 | 219 |
|
218 | 220 |
|
| 221 | +def _normalize_data_type(data_type: Optional[str]) -> str: |
| 222 | + return "UNKNOWN" if data_type is None else str(data_type) |
| 223 | + |
| 224 | + |
def _dataset_dimension_payload(dimension: sid.Dimension, axis: int) -> Dict[str, Any]:
    """Describe one dataset dimension as a dictionary of plain values.

    Args:
        dimension: The dimension to describe. It is converted with
            ``np.asarray`` to obtain its values.
        axis: Axis index reported under the ``"axis"`` key.

    Returns:
        A dictionary with the keys ``axis``, ``name``, ``quantity``, ``units``,
        ``dimension_type`` (the ``.name`` of the dimension type), ``length``
        and ``values`` (the values as a nested list).
    """
    axis_values = np.asarray(dimension)
    summary: Dict[str, Any] = {"axis": axis}
    summary["name"] = dimension.name
    summary["quantity"] = dimension.quantity
    summary["units"] = dimension.units
    summary["dimension_type"] = dimension.dimension_type.name
    summary["length"] = int(axis_values.size)
    summary["values"] = axis_values.tolist()
    return summary
| 236 | + |
| 237 | + |
def _dataset_payload(dataset: sid.Dataset, dataset_id: Optional[str] = None) -> Dict[str, Any]:
    """Summarize a dataset as a dictionary of plain values.

    Args:
        dataset: The dataset to summarize.
        dataset_id: Optional registry id; when given it is added to the result
            under the ``"dataset_id"`` key.

    Returns:
        A dictionary holding the dataset's shape, ndim, title, quantity, units,
        data type name, modality, source, both metadata dictionaries (passed
        through ``_as_builtin``) and one entry per dimension.
    """
    summary: Dict[str, Any] = {
        "shape": list(dataset.shape),
        "ndim": int(dataset.ndim),
        "title": dataset.title,
        "quantity": dataset.quantity,
        "units": dataset.units,
        "data_type": dataset.data_type.name,
        "modality": dataset.modality,
        "source": dataset.source,
        "metadata": _as_builtin(dataset.metadata),
        "original_metadata": _as_builtin(dataset.original_metadata),
    }

    # Walk the keys of ``dataset._axes`` in sorted order so the dimension list
    # is ordered by axis.
    dimension_summaries = []
    for axis_key in sorted(dataset._axes):
        dimension_summaries.append(_dataset_dimension_payload(dataset._axes[axis_key], axis_key))
    summary["dimensions"] = dimension_summaries

    if dataset_id is None:
        return summary
    summary["dataset_id"] = dataset_id
    return summary
| 255 | + |
| 256 | + |
def _store_dataset(dataset: sid.Dataset, dataset_id: Optional[str] = None) -> str:
    """Put ``dataset`` into ``DATASET_REGISTRY`` and return the id it is stored under.

    When ``dataset_id`` is ``None`` a fresh ``uuid4`` string is generated.
    An entry already stored under the same id is replaced.
    """
    registry_key = str(uuid4()) if dataset_id is None else dataset_id
    DATASET_REGISTRY[registry_key] = dataset
    return registry_key
| 262 | + |
| 263 | + |
def _get_dataset(dataset_id: str) -> sid.Dataset:
    """Look up a dataset in ``DATASET_REGISTRY``.

    Raises:
        KeyError: If ``dataset_id`` is not registered. The original lookup
            error is attached as the cause.
    """
    try:
        registered = DATASET_REGISTRY[dataset_id]
    except KeyError as missing:
        message = f"Unknown dataset_id '{dataset_id}'. Create or register a dataset first."
        raise KeyError(message) from missing
    return registered
| 269 | + |
| 270 | + |
| 271 | +def _merge_nested_dict(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]: |
| 272 | + merged = dict(base) |
| 273 | + for key, value in updates.items(): |
| 274 | + if isinstance(value, dict) and isinstance(merged.get(key), dict): |
| 275 | + merged[key] = _merge_nested_dict(merged[key], value) |
| 276 | + else: |
| 277 | + merged[key] = value |
| 278 | + return merged |
| 279 | + |
| 280 | + |
219 | 281 | def _build_dataset( |
220 | 282 | data: Sequence[Any], |
221 | 283 | spectral_axis: Sequence[float], |
@@ -262,6 +324,157 @@ def _build_dataset( |
262 | 324 | return dataset |
263 | 325 |
|
264 | 326 |
|
def _spec_value(spec: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``spec[key]``, or ``default`` when the key is absent or ``None``."""
    value = spec.get(key)
    return default if value is None else value


def create_dataset(
    data: Sequence[Any],
    *,
    dataset_name: str = "sidpy_dataset",
    data_type: Optional[str] = None,
    quantity: str = "generic",
    units: str = "generic",
    modality: str = "generic",
    source: str = "generic",
    dimensions: Optional[Sequence[Dict[str, Any]]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    original_metadata: Optional[Dict[str, Any]] = None,
    dataset_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create and register a sidpy.Dataset from nested array-like data.

    Args:
        data: Array-like data handed to ``sid.Dataset.from_array``.
        dataset_name: Title of the new dataset.
        data_type: Data type name; ``None`` is passed on as ``"UNKNOWN"``.
        quantity: Quantity label of the dataset.
        units: Units label of the dataset.
        modality: Modality label of the dataset.
        source: Source label of the dataset.
        dimensions: Optional sequence with one specification dictionary per
            dataset axis. Recognized keys and their fallbacks:
            ``axis`` (position in the sequence), ``values``
            (``np.arange`` of the axis length), ``name`` (``dim_<axis>``),
            ``quantity`` (the name), ``units`` (``"generic"``) and
            ``dimension_type`` (``"UNKNOWN"``). A key that is present but
            ``None`` is treated like a missing key.
        metadata: Optional dictionary copied onto ``dataset.metadata``.
        original_metadata: Optional dictionary copied onto
            ``dataset.original_metadata``.
        dataset_id: Optional registry id, forwarded to ``_store_dataset``.

    Returns:
        The summary produced by ``_dataset_payload`` for the new dataset,
        including its ``dataset_id``.

    Raises:
        ValueError: If ``dimensions`` does not contain exactly one entry per
            axis, names an axis twice, or does not cover every axis.
    """
    dataset = sid.Dataset.from_array(
        data,
        title=dataset_name,
        datatype=_normalize_data_type(data_type),
        quantity=quantity,
        units=units,
        modality=modality,
        source=source,
    )

    if dimensions is not None:
        if len(dimensions) != dataset.ndim:
            raise ValueError(
                "dimensions must provide one entry per dataset axis. "
                f"Received {len(dimensions)} dimensions for dataset.ndim={dataset.ndim}."
            )
        # Resolve each target axis once; the same list drives both the
        # validation below and the assignment loop.
        specified_axes = [
            int(_spec_value(dimension_spec, "axis", position))
            for position, dimension_spec in enumerate(dimensions)
        ]
        expected_axes = set(range(dataset.ndim))
        if len(set(specified_axes)) != len(specified_axes):
            raise ValueError("Each dimension axis may only be specified once.")
        if set(specified_axes) != expected_axes:
            raise ValueError(
                "dimensions must cover every dataset axis exactly once. "
                f"Received axes {specified_axes} for expected axes {sorted(expected_axes)}."
            )
        for axis_index, dimension_spec in zip(specified_axes, dimensions):
            # Explicit ``None`` entries (e.g. JSON nulls) fall back to the
            # same defaults as missing keys instead of reaching sid.Dimension.
            name = _spec_value(dimension_spec, "name", f"dim_{axis_index}")
            dataset.set_dimension(
                axis_index,
                sid.Dimension(
                    _spec_value(dimension_spec, "values", np.arange(dataset.shape[axis_index])),
                    name=name,
                    quantity=_spec_value(dimension_spec, "quantity", name),
                    units=_spec_value(dimension_spec, "units", "generic"),
                    dimension_type=_spec_value(dimension_spec, "dimension_type", "UNKNOWN"),
                ),
            )

    # Copy the caller's dictionaries so the dataset does not hold on to the
    # top-level objects that were passed in.
    if metadata is not None:
        dataset.metadata = dict(metadata)
    if original_metadata is not None:
        dataset.original_metadata = dict(original_metadata)

    dataset_id = _store_dataset(dataset, dataset_id=dataset_id)
    return _dataset_payload(dataset, dataset_id=dataset_id)
| 392 | + |
| 393 | + |
def get_dataset(dataset_id: str) -> Dict[str, Any]:
    """Return a JSON-friendly summary of the dataset registered under ``dataset_id``.

    The lookup goes through ``_get_dataset``, so an unregistered id raises
    ``KeyError``.
    """
    return _dataset_payload(_get_dataset(dataset_id), dataset_id=dataset_id)
| 398 | + |
| 399 | + |
def list_datasets() -> Dict[str, Any]:
    """List the datasets currently held in the in-memory registry.

    Returns:
        A dictionary with a single ``"datasets"`` key holding one entry per
        registered dataset (id, title, shape and data type name), in registry
        iteration order.
    """
    entries = []
    for registry_key, stored in DATASET_REGISTRY.items():
        entry = {
            "dataset_id": registry_key,
            "title": stored.title,
            "shape": list(stored.shape),
            "data_type": stored.data_type.name,
        }
        entries.append(entry)
    return {"datasets": entries}
| 413 | + |
| 414 | + |
def add_metadata(
    dataset_id: str,
    metadata: Dict[str, Any],
    *,
    merge: bool = True,
    target: str = "metadata",
) -> Dict[str, Any]:
    """Merge into or replace one of the metadata dictionaries of a registered dataset.

    Args:
        dataset_id: Registry id of the dataset to modify.
        metadata: The metadata to apply.
        merge: When true, ``metadata`` is merged into the existing dictionary
            with ``_merge_nested_dict``; otherwise a copy of ``metadata``
            replaces it.
        target: Attribute to modify, compared case-insensitively; either
            ``"metadata"`` or ``"original_metadata"``.

    Returns:
        The updated dataset summary from ``_dataset_payload``.

    Raises:
        ValueError: If ``target`` names neither supported attribute.
    """
    dataset = _get_dataset(dataset_id)
    attribute = str(target).lower()
    if attribute != "metadata" and attribute != "original_metadata":
        raise ValueError("target must be either 'metadata' or 'original_metadata'.")

    existing = getattr(dataset, attribute)
    if merge:
        replacement = _merge_nested_dict(existing, metadata)
    else:
        replacement = dict(metadata)
    setattr(dataset, attribute, replacement)
    return _dataset_payload(dataset, dataset_id=dataset_id)
| 432 | + |
| 433 | + |
def update_dimension(
    dataset_id: str,
    axis: int,
    *,
    values: Sequence[float],
    name: Optional[str] = None,
    quantity: Optional[str] = None,
    units: Optional[str] = None,
    dimension_type: Optional[str] = None,
) -> Dict[str, Any]:
    """Replace one dataset dimension using sidpy.Dimension and Dataset.set_dimension.

    Args:
        dataset_id: Registry id of the dataset to modify.
        axis: Axis whose dimension is replaced; converted with ``int``.
        values: Values of the replacement dimension.
        name: New name. ``None`` or an empty string keeps the existing name.
        quantity: New quantity. ``None`` or an empty string keeps the
            existing quantity.
        units: New units. ``None`` or an empty string keeps the existing
            units.
        dimension_type: New dimension type. Only ``None`` keeps the existing
            type.

    Returns:
        The updated dataset summary from ``_dataset_payload``.

    Raises:
        KeyError: If ``dataset_id`` is not registered, or ``axis`` is not a
            key of the dataset's axes.
    """
    dataset = _get_dataset(dataset_id)
    axis = int(axis)
    try:
        existing_dimension = dataset._axes[axis]
    except KeyError as exc:
        # Same exception type as the bare lookup, but with enough context to
        # tell which dataset and axis were requested.
        raise KeyError(
            f"Unknown axis {axis} for dataset_id '{dataset_id}' (dataset.ndim={dataset.ndim})."
        ) from exc

    # ``or`` is kept on purpose for the text fields: falsy values such as an
    # empty string fall back to the existing attribute, as before.
    replacement = sid.Dimension(
        values,
        name=name or existing_dimension.name,
        quantity=quantity or existing_dimension.quantity,
        units=units or existing_dimension.units,
        dimension_type=existing_dimension.dimension_type if dimension_type is None else dimension_type,
    )
    dataset.set_dimension(axis, replacement)
    return _dataset_payload(dataset, dataset_id=dataset_id)
| 457 | + |
| 458 | + |
def rename_dimension(dataset_id: str, axis: int, name: str) -> Dict[str, Any]:
    """Rename one dimension of a registered dataset and return the updated summary.

    ``axis`` is converted with ``int`` before it is handed to
    ``Dataset.rename_dimension``.
    """
    target = _get_dataset(dataset_id)
    axis_index = int(axis)
    target.rename_dimension(axis_index, name)
    return _dataset_payload(target, dataset_id=dataset_id)
| 464 | + |
| 465 | + |
def remove_dataset(dataset_id: str) -> Dict[str, Any]:
    """Delete a dataset from the in-memory registry.

    Returns:
        A dictionary with the removed ``dataset_id``, the dataset's ``title``
        and ``"removed": True``.

    Raises:
        KeyError: If ``dataset_id`` is not registered (raised by
            ``_get_dataset``).
    """
    # Read the title before the entry is dropped from the registry.
    removed_title = _get_dataset(dataset_id).title
    del DATASET_REGISTRY[dataset_id]
    return {"dataset_id": dataset_id, "title": removed_title, "removed": True}
| 476 | + |
| 477 | + |
265 | 478 | def _package_result(result: Any) -> Dict[str, Any]: |
266 | 479 | """Normalize fitter outputs into a JSON-friendly payload.""" |
267 | 480 | if isinstance(result, tuple): |
@@ -386,12 +599,92 @@ def fit_sho_response( |
386 | 599 |
|
387 | 600 |
|
388 | 601 | def create_mcp_server(server_name: str = "sidpy-beps-fitting"): |
389 | | - """Create an MCP server exposing the BEPS loop and SHO fitting tools.""" |
| 602 | + """Create an MCP server exposing sidpy dataset and fitting tools.""" |
390 | 603 | if FastMCP is None: # pragma: no cover - optional runtime dependency |
391 | 604 | raise ImportError("The 'mcp' package is required to create the BEPS MCP server.") |
392 | 605 |
|
393 | 606 | server = FastMCP(server_name) |
394 | 607 |
|
| 608 | + @server.tool() |
| 609 | + def create_dataset_tool( |
| 610 | + data: Sequence[Any], |
| 611 | + dataset_name: str = "sidpy_dataset", |
| 612 | + data_type: Optional[str] = None, |
| 613 | + quantity: str = "generic", |
| 614 | + units: str = "generic", |
| 615 | + modality: str = "generic", |
| 616 | + source: str = "generic", |
| 617 | + dimensions: Optional[Sequence[Dict[str, Any]]] = None, |
| 618 | + metadata: Optional[Dict[str, Any]] = None, |
| 619 | + original_metadata: Optional[Dict[str, Any]] = None, |
| 620 | + dataset_id: Optional[str] = None, |
| 621 | + ) -> Dict[str, Any]: |
| 622 | + """Create a sidpy.Dataset and store it in the MCP server registry.""" |
| 623 | + return create_dataset( |
| 624 | + data, |
| 625 | + dataset_name=dataset_name, |
| 626 | + data_type=data_type, |
| 627 | + quantity=quantity, |
| 628 | + units=units, |
| 629 | + modality=modality, |
| 630 | + source=source, |
| 631 | + dimensions=dimensions, |
| 632 | + metadata=metadata, |
| 633 | + original_metadata=original_metadata, |
| 634 | + dataset_id=dataset_id, |
| 635 | + ) |
| 636 | + |
| 637 | + @server.tool() |
| 638 | + def get_dataset_tool(dataset_id: str) -> Dict[str, Any]: |
| 639 | + """Return a registered dataset summary including dimensions and metadata.""" |
| 640 | + return get_dataset(dataset_id) |
| 641 | + |
| 642 | + @server.tool() |
| 643 | + def list_datasets_tool() -> Dict[str, Any]: |
| 644 | + """List the registered dataset ids in this MCP server process.""" |
| 645 | + return list_datasets() |
| 646 | + |
| 647 | + @server.tool() |
| 648 | + def add_metadata_tool( |
| 649 | + dataset_id: str, |
| 650 | + metadata: Dict[str, Any], |
| 651 | + merge: bool = True, |
| 652 | + target: str = "metadata", |
| 653 | + ) -> Dict[str, Any]: |
| 654 | + """Add or replace metadata on a registered sidpy dataset.""" |
| 655 | + return add_metadata(dataset_id, metadata, merge=merge, target=target) |
| 656 | + |
| 657 | + @server.tool() |
| 658 | + def update_dimension_tool( |
| 659 | + dataset_id: str, |
| 660 | + axis: int, |
| 661 | + values: Sequence[float], |
| 662 | + name: Optional[str] = None, |
| 663 | + quantity: Optional[str] = None, |
| 664 | + units: Optional[str] = None, |
| 665 | + dimension_type: Optional[str] = None, |
| 666 | + ) -> Dict[str, Any]: |
| 667 | + """Replace one dataset dimension with new values and optional metadata.""" |
| 668 | + return update_dimension( |
| 669 | + dataset_id, |
| 670 | + axis, |
| 671 | + values=values, |
| 672 | + name=name, |
| 673 | + quantity=quantity, |
| 674 | + units=units, |
| 675 | + dimension_type=dimension_type, |
| 676 | + ) |
| 677 | + |
| 678 | + @server.tool() |
| 679 | + def rename_dimension_tool(dataset_id: str, axis: int, name: str) -> Dict[str, Any]: |
| 680 | + """Rename one dataset dimension.""" |
| 681 | + return rename_dimension(dataset_id, axis, name) |
| 682 | + |
| 683 | + @server.tool() |
| 684 | + def remove_dataset_tool(dataset_id: str) -> Dict[str, Any]: |
| 685 | + """Remove a dataset from the in-memory MCP registry.""" |
| 686 | + return remove_dataset(dataset_id) |
| 687 | + |
395 | 688 | @server.tool() |
396 | 689 | def fit_beps_loops_tool( |
397 | 690 | data: Sequence[Any], |
@@ -468,15 +761,23 @@ def main(): |
468 | 761 |
|
469 | 762 |
|
470 | 763 | __all__ = [ |
| 764 | + "DATASET_REGISTRY", |
471 | 765 | "LOOP_PARAMETER_LABELS", |
472 | 766 | "SHO_PARAMETER_LABELS", |
473 | 767 | "SHO_fit_flattened", |
| 768 | + "add_metadata", |
474 | 769 | "calculate_loop_centroid", |
| 770 | + "create_dataset", |
475 | 771 | "create_mcp_server", |
476 | 772 | "fit_beps_loops", |
477 | 773 | "fit_sho_response", |
478 | 774 | "generate_guess", |
| 775 | + "get_dataset", |
| 776 | + "list_datasets", |
479 | 777 | "loop_fit_function", |
480 | 778 | "main", |
| 779 | + "remove_dataset", |
| 780 | + "rename_dimension", |
481 | 781 | "sho_guess_fn", |
| 782 | + "update_dimension", |
482 | 783 | ] |
0 commit comments