@@ -524,6 +524,10 @@ def check_dataset(
524524 target_splits = target_splits & dataset_splits
525525
526526 checksum_report = {}
527+ # Track shape of each Global feature across all checked splits/samples to
528+ # detect inconsistencies (e.g. a Global stored as a scalar in one sample
529+ # and as a vector in another).
530+ global_shape_observations : dict [str , dict [tuple , list [str ]]] = {}
527531 for split in sorted (target_splits ):
528532 dataset = datasetdict [split ]
529533 converter = converterdict [split ]
@@ -595,6 +599,19 @@ def check_dataset(
595599 issue ,
596600 )
597601
602+ # Record the observed shape of this Global so we can later
603+ # detect dimension mismatches across all checked samples
604+ # (across splits).
605+ if value is not None :
606+ try :
607+ shape = tuple (np .asarray (value ).shape )
608+ except Exception :
609+ shape = None
610+ if shape is not None :
611+ global_shape_observations .setdefault (
612+ global_name , {}
613+ ).setdefault (shape , []).append (f"{ split } [{ idx } ]" )
614+
598615 for time in sample .get_all_time_values ():
599616 local_bases = sample .get_base_names (time = time )
600617 for base in local_bases :
@@ -625,6 +642,25 @@ def check_dataset(
625642 issue ,
626643 )
627644
645+ # Report Globals whose dimension/shape is not consistent across all
646+ # checked samples (across splits).
647+ for global_name , shape_to_locations in global_shape_observations .items ():
648+ if len (shape_to_locations ) <= 1 :
649+ continue
650+ details = "; " .join (
651+ f"shape={ shape } at { locations [:5 ]} "
652+ + (f" (+{ len (locations ) - 5 } more)" if len (locations ) > 5 else "" )
653+ for shape , locations in sorted (
654+ shape_to_locations .items (), key = lambda kv : str (kv [0 ])
655+ )
656+ )
657+ report .add (
658+ "error" ,
659+ "GLOBAL_SHAPE_MISMATCH" ,
660+ f"global/{ global_name } " ,
661+ f"Global '{ global_name } ' has inconsistent shapes across samples: { details } " ,
662+ )
663+
628664 # Compare checksums from every checked sample to flag identical sample data.
629665 checksum_values = list (checksum_report .values ())
630666 if len (checksum_report ) != len (np .unique (checksum_values )):
0 commit comments