55import json
66from dataclasses import asdict , dataclass , field
77from pathlib import Path
8- from typing import Any , Callable , Mapping
8+ from typing import Any , Callable , Mapping , Sequence
99
1010import numpy as np
1111
@@ -22,6 +22,10 @@ class PredictionArtifactMetadata:
2222 input_shape : tuple [int , ...] | None = None
2323 final_shape : tuple [int , ...] | None = None
2424 crop_pad : tuple [tuple [int , int ], ...] | None = None
25+ transpose : tuple [int , ...] | None = None
26+ model_architecture : str | None = None
27+ model_output_identity : str | None = None
28+ decode_after_inference : bool | None = None
2529 chunk_shape : tuple [int , ...] | None = None
2630 halo : tuple [int , ...] | None = None
2731 channel_order : tuple [str , ...] | None = None
@@ -31,6 +35,86 @@ class PredictionArtifactMetadata:
3135 extra : Mapping [str , Any ] = field (default_factory = dict )
3236
3337
38+ def _cfg_get (obj : Any , path : str , default : Any = None ) -> Any :
39+ node = obj
40+ for part in path .split ("." ):
41+ if node is None :
42+ return default
43+ if isinstance (node , Mapping ):
44+ node = node .get (part , default )
45+ else :
46+ node = getattr (node , part , default )
47+ return node
48+
49+
50+ def _tuple_or_none (value : Sequence [Any ] | None ) -> tuple [int , ...] | None :
51+ if value in (None , [], ()):
52+ return None
53+ return tuple (int (v ) for v in value )
54+
55+
56+ def _model_output_identity (cfg : Any , output_head : str | None ) -> str | None :
57+ parts : list [str ] = []
58+ if output_head :
59+ parts .append (f"head={ output_head } " )
60+ else :
61+ primary_head = _cfg_get (cfg , "model.primary_head" )
62+ if primary_head :
63+ parts .append (f"primary_head={ primary_head } " )
64+
65+ select_channel = _cfg_get (cfg , "inference.select_channel" )
66+ if select_channel is not None :
67+ parts .append (f"select_channel={ select_channel } " )
68+
69+ return ";" .join (parts ) if parts else None
70+
71+
72+ def build_prediction_artifact_metadata (
73+ cfg : Any ,
74+ * ,
75+ image_path : str | None = None ,
76+ checkpoint_path : str | None = None ,
77+ output_head : str | None = None ,
78+ input_shape : Sequence [int ] | None = None ,
79+ final_shape : Sequence [int ] | None = None ,
80+ crop_pad : Sequence [Sequence [int ]] | None = None ,
81+ chunk_shape : Sequence [int ] | None = None ,
82+ halo : Sequence [int ] | None = None ,
83+ intensity_scale : float | None = None ,
84+ intensity_dtype : str | None = None ,
85+ extra : Mapping [str , Any ] | None = None ,
86+ ) -> PredictionArtifactMetadata :
87+ """Build standard metadata for raw prediction artifacts."""
88+ transform_cfg = _cfg_get (cfg , "inference.prediction_transform" )
89+ transform_enabled = bool (getattr (transform_cfg , "enabled" , False ))
90+ if intensity_scale is None and transform_enabled :
91+ intensity_scale = float (getattr (transform_cfg , "intensity_scale" , - 1.0 ))
92+ if intensity_dtype is None and transform_enabled :
93+ intensity_dtype = getattr (transform_cfg , "intensity_dtype" , None )
94+
95+ return PredictionArtifactMetadata (
96+ image_path = image_path ,
97+ checkpoint_path = str (checkpoint_path ) if checkpoint_path is not None else None ,
98+ output_head = output_head ,
99+ input_shape = _tuple_or_none (input_shape ),
100+ final_shape = _tuple_or_none (final_shape ),
101+ crop_pad = (
102+ tuple ((int (pair [0 ]), int (pair [1 ])) for pair in crop_pad )
103+ if crop_pad is not None
104+ else None
105+ ),
106+ transpose = _tuple_or_none (_cfg_get (cfg , "data.data_transform.val_transpose" )),
107+ model_architecture = _cfg_get (cfg , "model.arch.type" ),
108+ model_output_identity = _model_output_identity (cfg , output_head ),
109+ decode_after_inference = bool (_cfg_get (cfg , "inference.decode_after_inference" , True )),
110+ chunk_shape = _tuple_or_none (chunk_shape ),
111+ halo = _tuple_or_none (halo ),
112+ intensity_scale = intensity_scale ,
113+ intensity_dtype = intensity_dtype ,
114+ extra = extra or {},
115+ )
116+
117+
34118def _json_attr (value : Any ) -> Any :
35119 if value is None or isinstance (value , (str , int , float , bool )):
36120 return value
@@ -140,6 +224,7 @@ def read_prediction_artifact(
140224
141225__all__ = [
142226 "PredictionArtifactMetadata" ,
227+ "build_prediction_artifact_metadata" ,
143228 "read_prediction_artifact" ,
144229 "write_prediction_artifact" ,
145230 "write_prediction_artifact_attrs" ,
0 commit comments