Skip to content

Commit 3e546b8

Browse files
author
Rama Vasudevan
committed
Add core sidpy MCP dataset tools
1 parent 273e048 commit 3e546b8

2 files changed

Lines changed: 457 additions & 1 deletion

File tree

sidpy/proc/mcp_server_beps.py

Lines changed: 302 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from typing import Any, Dict, Iterable, Optional, Sequence
15+
from uuid import uuid4
1516

1617
import numpy as np
1718
import sidpy as sid
@@ -39,6 +40,7 @@
3940
]
4041

4142
SHO_PARAMETER_LABELS = ["amplitude", "resonance_frequency", "quality_factor", "phase"]
43+
DATASET_REGISTRY: Dict[str, sid.Dataset] = {}
4244

4345

4446
def loop_fit_function(vdc: Sequence[float], *coef_vec: float) -> np.ndarray:
@@ -216,6 +218,66 @@ def _as_builtin(value: Any) -> Any:
216218
return value
217219

218220

221+
def _normalize_data_type(data_type: Optional[str]) -> str:
222+
return "UNKNOWN" if data_type is None else str(data_type)
223+
224+
225+
def _dataset_dimension_payload(dimension: sid.Dimension, axis: int) -> Dict[str, Any]:
226+
values = np.asarray(dimension)
227+
return {
228+
"axis": axis,
229+
"name": dimension.name,
230+
"quantity": dimension.quantity,
231+
"units": dimension.units,
232+
"dimension_type": dimension.dimension_type.name,
233+
"length": int(values.size),
234+
"values": values.tolist(),
235+
}
236+
237+
238+
def _dataset_payload(dataset: sid.Dataset, dataset_id: Optional[str] = None) -> Dict[str, Any]:
239+
payload = {
240+
"shape": list(dataset.shape),
241+
"ndim": int(dataset.ndim),
242+
"title": dataset.title,
243+
"quantity": dataset.quantity,
244+
"units": dataset.units,
245+
"data_type": dataset.data_type.name,
246+
"modality": dataset.modality,
247+
"source": dataset.source,
248+
"metadata": _as_builtin(dataset.metadata),
249+
"original_metadata": _as_builtin(dataset.original_metadata),
250+
"dimensions": [_dataset_dimension_payload(dataset._axes[axis], axis) for axis in sorted(dataset._axes)],
251+
}
252+
if dataset_id is not None:
253+
payload["dataset_id"] = dataset_id
254+
return payload
255+
256+
257+
def _store_dataset(dataset: sid.Dataset, dataset_id: Optional[str] = None) -> str:
258+
if dataset_id is None:
259+
dataset_id = str(uuid4())
260+
DATASET_REGISTRY[dataset_id] = dataset
261+
return dataset_id
262+
263+
264+
def _get_dataset(dataset_id: str) -> sid.Dataset:
265+
try:
266+
return DATASET_REGISTRY[dataset_id]
267+
except KeyError as exc:
268+
raise KeyError(f"Unknown dataset_id '{dataset_id}'. Create or register a dataset first.") from exc
269+
270+
271+
def _merge_nested_dict(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
272+
merged = dict(base)
273+
for key, value in updates.items():
274+
if isinstance(value, dict) and isinstance(merged.get(key), dict):
275+
merged[key] = _merge_nested_dict(merged[key], value)
276+
else:
277+
merged[key] = value
278+
return merged
279+
280+
219281
def _build_dataset(
220282
data: Sequence[Any],
221283
spectral_axis: Sequence[float],
@@ -262,6 +324,157 @@ def _build_dataset(
262324
return dataset
263325

264326

327+
def create_dataset(
328+
data: Sequence[Any],
329+
*,
330+
dataset_name: str = "sidpy_dataset",
331+
data_type: Optional[str] = None,
332+
quantity: str = "generic",
333+
units: str = "generic",
334+
modality: str = "generic",
335+
source: str = "generic",
336+
dimensions: Optional[Sequence[Dict[str, Any]]] = None,
337+
metadata: Optional[Dict[str, Any]] = None,
338+
original_metadata: Optional[Dict[str, Any]] = None,
339+
dataset_id: Optional[str] = None,
340+
) -> Dict[str, Any]:
341+
"""Create and register a sidpy.Dataset from nested array-like data."""
342+
dataset = sid.Dataset.from_array(
343+
data,
344+
title=dataset_name,
345+
datatype=_normalize_data_type(data_type),
346+
quantity=quantity,
347+
units=units,
348+
modality=modality,
349+
source=source,
350+
)
351+
352+
if dimensions is not None:
353+
if len(dimensions) != dataset.ndim:
354+
raise ValueError(
355+
"dimensions must provide one entry per dataset axis. "
356+
f"Received {len(dimensions)} dimensions for dataset.ndim={dataset.ndim}."
357+
)
358+
specified_axes = [int(dimension_spec.get("axis", axis)) for axis, dimension_spec in enumerate(dimensions)]
359+
expected_axes = set(range(dataset.ndim))
360+
if len(set(specified_axes)) != len(specified_axes):
361+
raise ValueError("Each dimension axis may only be specified once.")
362+
if set(specified_axes) != expected_axes:
363+
raise ValueError(
364+
"dimensions must cover every dataset axis exactly once. "
365+
f"Received axes {specified_axes} for expected axes {sorted(expected_axes)}."
366+
)
367+
for axis, dimension_spec in enumerate(dimensions):
368+
axis_index = int(dimension_spec.get("axis", axis))
369+
values = dimension_spec.get("values", np.arange(dataset.shape[axis_index]))
370+
name = dimension_spec.get("name", f"dim_{axis_index}")
371+
quantity_value = dimension_spec.get("quantity", name)
372+
units_value = dimension_spec.get("units", "generic")
373+
dim_type = dimension_spec.get("dimension_type", "UNKNOWN")
374+
dataset.set_dimension(
375+
axis_index,
376+
sid.Dimension(
377+
values,
378+
name=name,
379+
quantity=quantity_value,
380+
units=units_value,
381+
dimension_type=dim_type,
382+
),
383+
)
384+
385+
if metadata is not None:
386+
dataset.metadata = dict(metadata)
387+
if original_metadata is not None:
388+
dataset.original_metadata = dict(original_metadata)
389+
390+
dataset_id = _store_dataset(dataset, dataset_id=dataset_id)
391+
return _dataset_payload(dataset, dataset_id=dataset_id)
392+
393+
394+
def get_dataset(dataset_id: str) -> Dict[str, Any]:
395+
"""Return a JSON-friendly summary of a registered dataset."""
396+
dataset = _get_dataset(dataset_id)
397+
return _dataset_payload(dataset, dataset_id=dataset_id)
398+
399+
400+
def list_datasets() -> Dict[str, Any]:
401+
"""List dataset ids currently stored in the in-memory MCP registry."""
402+
return {
403+
"datasets": [
404+
{
405+
"dataset_id": dataset_id,
406+
"title": dataset.title,
407+
"shape": list(dataset.shape),
408+
"data_type": dataset.data_type.name,
409+
}
410+
for dataset_id, dataset in DATASET_REGISTRY.items()
411+
]
412+
}
413+
414+
415+
def add_metadata(
416+
dataset_id: str,
417+
metadata: Dict[str, Any],
418+
*,
419+
merge: bool = True,
420+
target: str = "metadata",
421+
) -> Dict[str, Any]:
422+
"""Add or replace metadata on a registered dataset."""
423+
dataset = _get_dataset(dataset_id)
424+
target_name = str(target).lower()
425+
if target_name not in {"metadata", "original_metadata"}:
426+
raise ValueError("target must be either 'metadata' or 'original_metadata'.")
427+
428+
current = getattr(dataset, target_name)
429+
next_value = _merge_nested_dict(current, metadata) if merge else dict(metadata)
430+
setattr(dataset, target_name, next_value)
431+
return _dataset_payload(dataset, dataset_id=dataset_id)
432+
433+
434+
def update_dimension(
435+
dataset_id: str,
436+
axis: int,
437+
*,
438+
values: Sequence[float],
439+
name: Optional[str] = None,
440+
quantity: Optional[str] = None,
441+
units: Optional[str] = None,
442+
dimension_type: Optional[str] = None,
443+
) -> Dict[str, Any]:
444+
"""Replace one dataset dimension using sidpy.Dimension and Dataset.set_dimension."""
445+
dataset = _get_dataset(dataset_id)
446+
axis = int(axis)
447+
existing_dimension = dataset._axes[axis]
448+
replacement = sid.Dimension(
449+
values,
450+
name=name or existing_dimension.name,
451+
quantity=quantity or existing_dimension.quantity,
452+
units=units or existing_dimension.units,
453+
dimension_type=existing_dimension.dimension_type if dimension_type is None else dimension_type,
454+
)
455+
dataset.set_dimension(axis, replacement)
456+
return _dataset_payload(dataset, dataset_id=dataset_id)
457+
458+
459+
def rename_dimension(dataset_id: str, axis: int, name: str) -> Dict[str, Any]:
460+
"""Rename one registered dataset dimension."""
461+
dataset = _get_dataset(dataset_id)
462+
dataset.rename_dimension(int(axis), name)
463+
return _dataset_payload(dataset, dataset_id=dataset_id)
464+
465+
466+
def remove_dataset(dataset_id: str) -> Dict[str, Any]:
467+
"""Remove a dataset from the in-memory registry."""
468+
dataset = _get_dataset(dataset_id)
469+
payload = {
470+
"dataset_id": dataset_id,
471+
"title": dataset.title,
472+
"removed": True,
473+
}
474+
del DATASET_REGISTRY[dataset_id]
475+
return payload
476+
477+
265478
def _package_result(result: Any) -> Dict[str, Any]:
266479
"""Normalize fitter outputs into a JSON-friendly payload."""
267480
if isinstance(result, tuple):
@@ -386,12 +599,92 @@ def fit_sho_response(
386599

387600

388601
def create_mcp_server(server_name: str = "sidpy-beps-fitting"):
389-
"""Create an MCP server exposing the BEPS loop and SHO fitting tools."""
602+
"""Create an MCP server exposing sidpy dataset and fitting tools."""
390603
if FastMCP is None: # pragma: no cover - optional runtime dependency
391604
raise ImportError("The 'mcp' package is required to create the BEPS MCP server.")
392605

393606
server = FastMCP(server_name)
394607

608+
@server.tool()
609+
def create_dataset_tool(
610+
data: Sequence[Any],
611+
dataset_name: str = "sidpy_dataset",
612+
data_type: Optional[str] = None,
613+
quantity: str = "generic",
614+
units: str = "generic",
615+
modality: str = "generic",
616+
source: str = "generic",
617+
dimensions: Optional[Sequence[Dict[str, Any]]] = None,
618+
metadata: Optional[Dict[str, Any]] = None,
619+
original_metadata: Optional[Dict[str, Any]] = None,
620+
dataset_id: Optional[str] = None,
621+
) -> Dict[str, Any]:
622+
"""Create a sidpy.Dataset and store it in the MCP server registry."""
623+
return create_dataset(
624+
data,
625+
dataset_name=dataset_name,
626+
data_type=data_type,
627+
quantity=quantity,
628+
units=units,
629+
modality=modality,
630+
source=source,
631+
dimensions=dimensions,
632+
metadata=metadata,
633+
original_metadata=original_metadata,
634+
dataset_id=dataset_id,
635+
)
636+
637+
@server.tool()
638+
def get_dataset_tool(dataset_id: str) -> Dict[str, Any]:
639+
"""Return a registered dataset summary including dimensions and metadata."""
640+
return get_dataset(dataset_id)
641+
642+
@server.tool()
643+
def list_datasets_tool() -> Dict[str, Any]:
644+
"""List the registered dataset ids in this MCP server process."""
645+
return list_datasets()
646+
647+
@server.tool()
648+
def add_metadata_tool(
649+
dataset_id: str,
650+
metadata: Dict[str, Any],
651+
merge: bool = True,
652+
target: str = "metadata",
653+
) -> Dict[str, Any]:
654+
"""Add or replace metadata on a registered sidpy dataset."""
655+
return add_metadata(dataset_id, metadata, merge=merge, target=target)
656+
657+
@server.tool()
658+
def update_dimension_tool(
659+
dataset_id: str,
660+
axis: int,
661+
values: Sequence[float],
662+
name: Optional[str] = None,
663+
quantity: Optional[str] = None,
664+
units: Optional[str] = None,
665+
dimension_type: Optional[str] = None,
666+
) -> Dict[str, Any]:
667+
"""Replace one dataset dimension with new values and optional metadata."""
668+
return update_dimension(
669+
dataset_id,
670+
axis,
671+
values=values,
672+
name=name,
673+
quantity=quantity,
674+
units=units,
675+
dimension_type=dimension_type,
676+
)
677+
678+
@server.tool()
679+
def rename_dimension_tool(dataset_id: str, axis: int, name: str) -> Dict[str, Any]:
680+
"""Rename one dataset dimension."""
681+
return rename_dimension(dataset_id, axis, name)
682+
683+
@server.tool()
684+
def remove_dataset_tool(dataset_id: str) -> Dict[str, Any]:
685+
"""Remove a dataset from the in-memory MCP registry."""
686+
return remove_dataset(dataset_id)
687+
395688
@server.tool()
396689
def fit_beps_loops_tool(
397690
data: Sequence[Any],
@@ -468,15 +761,23 @@ def main():
468761

469762

470763
__all__ = [
764+
"DATASET_REGISTRY",
471765
"LOOP_PARAMETER_LABELS",
472766
"SHO_PARAMETER_LABELS",
473767
"SHO_fit_flattened",
768+
"add_metadata",
474769
"calculate_loop_centroid",
770+
"create_dataset",
475771
"create_mcp_server",
476772
"fit_beps_loops",
477773
"fit_sho_response",
478774
"generate_guess",
775+
"get_dataset",
776+
"list_datasets",
479777
"loop_fit_function",
480778
"main",
779+
"remove_dataset",
780+
"rename_dimension",
481781
"sho_guess_fn",
782+
"update_dimension",
482783
]

0 commit comments

Comments
 (0)