@@ -471,9 +471,10 @@ class StreamParser(object):
471471 Stream parser.
472472 """
473473
474- def __init__ (self , seq_names , stream ):
474+ def __init__ (self , seq_names , stream , use_lazy_data_integrity_checks = False ):
475475 self .seq_names = seq_names
476476 self .stream = stream
477+ self .use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
477478
478479 self .num_features = None
479480 self .feature_type = None # 1 for sparse, 2 for dense
@@ -518,8 +519,10 @@ def __init__(self, *args, **kwargs):
518519 if self .dtype is None :
519520 self .dtype = str (seq_data .dtype )
520521
521- assert seq_data .shape [1 ] == self .num_features
522- assert str (seq_data .dtype ) == self .dtype
522+ if self .use_lazy_data_integrity_checks :
523+ break
524+
525+ self .check_data_integrity (seq_data , s )
523526
524527 self .feature_type = 2
525528
@@ -528,7 +531,12 @@ def get_data(self, seq_name):
528531 :param str seq_name:
529532 :rtype: numpy.ndarray
530533 """
531- return self .stream ["data" ][seq_name ][...]
534+ data = self .stream ["data" ][seq_name ][...]
535+
536+ if self .use_lazy_data_integrity_checks :
537+ self .check_data_integrity (data , seq_name )
538+
539+ return data
532540
533541 def get_seq_length (self , seq_name ):
534542 """
@@ -537,6 +545,18 @@ def get_seq_length(self, seq_name):
537545 """
538546 return self .stream ["data" ][seq_name ].shape [0 ]
539547
548+ def check_data_integrity (self , data , seq_name ):
549+ """
550+ :param numpy.ndarray data
551+ :param str seq_name
552+ """
553+
554+ assert len (data .shape ) == 2 , f"shape length mismatch in { seq_name } : { data .shape } (should be 2-dimensional)"
555+ assert (
556+ self .num_features == data .shape [1 ]
557+ ), f"feature dim mismatch in { seq_name } : { data .shape [1 ]} (should be { self .num_features } )"
558+ assert self .dtype == str (data .dtype ), f"dtype mismatch { seq_name } : { str (data .dtype )} (should be { self .dtype } )"
559+
540560
541561class SparseStreamParser (StreamParser ):
542562 """
@@ -552,7 +572,11 @@ def __init__(self, *args, **kwargs):
552572
553573 if self .dtype is None :
554574 self .dtype = str (seq_data .dtype )
555- assert str (seq_data .dtype ) == self .dtype
575+
576+ if self .use_lazy_data_integrity_checks :
577+ break
578+
579+ self .check_data_integrity (seq_data , s )
556580
557581 self .num_features = self .stream ["feature_names" ].shape [0 ]
558582 self .feature_type = 1
@@ -562,7 +586,12 @@ def get_data(self, seq_name):
562586 :param str seq_name:
563587 :rtype: numpy.ndarray
564588 """
565- return self .stream ["data" ][seq_name ][:]
589+ data = self .stream ["data" ][seq_name ][:]
590+
591+ if self .use_lazy_data_integrity_checks :
592+ self .check_data_integrity (data , seq_name )
593+
594+ return data
566595
567596 def get_seq_length (self , seq_name ):
568597 """
@@ -571,6 +600,17 @@ def get_seq_length(self, seq_name):
571600 """
572601 return self .stream ["data" ][seq_name ].shape [0 ]
573602
603+ def check_data_integrity (self , data , seq_name ):
604+ """
605+ :param numpy.ndarray data
606+ :param str seq_name
607+ """
608+
609+ assert len (data .shape ) == 1 , f"shape length mismatch in { seq_name } : { data .shape } (should be 2-dimensional)"
610+ assert self .dtype == str (
611+ data .dtype
612+ ), f"dtype mismatch in { seq_name } : { str (data .dtype )} (should be { self .dtype } )"
613+
574614
575615class SegmentAlignmentStreamParser (StreamParser ):
576616 """
@@ -585,10 +625,11 @@ def __init__(self, *args, **kwargs):
585625
586626 if self .dtype is None :
587627 self .dtype = str (seq_data .dtype )
588- assert str (seq_data .dtype ) == self .dtype
589628
590- assert len (seq_data .shape ) == 2
591- assert seq_data .shape [1 ] == 2
629+ if self .use_lazy_data_integrity_checks :
630+ break
631+
632+ self .check_data_integrity (seq_data , s )
592633
593634 self .num_features = self .stream ["feature_names" ].shape [0 ]
594635 self .feature_type = 1
@@ -602,6 +643,9 @@ def get_data(self, seq_name):
602643 length = self .get_seq_length (seq_name ) // 2
603644 segments = self .stream ["data" ][seq_name ][:]
604645
646+ if self .use_lazy_data_integrity_checks :
647+ self .check_data_integrity (segments , seq_name )
648+
605649 alignment = numpy .zeros ((length , 2 ), dtype = self .dtype )
606650 num_segments = segments .shape [0 ]
607651 seg_end = 0
@@ -621,6 +665,22 @@ def get_seq_length(self, seq_name):
621665 """
622666 return 2 * sum (self .stream ["data" ][seq_name ][:, 1 ])
623667
668+ def check_data_integrity (self , data , seq_name ):
669+ """
670+ :param numpy.ndarray data
671+ :param str seq_name
672+ """
673+
674+ assert (
675+ len (data .shape ) == 2
676+ ), f"shape length mismatch in { seq_name } : { data .shape } (should be 2-dimensional)"
677+ assert (
678+ data .shape [1 ] == 2
679+ ), f"feature dim mismatch in { seq_name } : { data .shape [1 ]} (should be 2-dimensional)"
680+ assert self .dtype == str (
681+ data .dtype
682+ ), f"dtype mismatch in { seq_name } : { str (data .dtype )} (should be { self .dtype } )"
683+
624684
625685class NextGenHDFDataset (CachedDataset2 ):
626686 """
@@ -633,7 +693,7 @@ class NextGenHDFDataset(CachedDataset2):
633693 "segment_alignment" : SegmentAlignmentStreamParser ,
634694 }
635695
636- def __init__ (self , input_stream_name , files = None , ** kwargs ):
696+ def __init__ (self , input_stream_name , files = None , use_lazy_data_integrity_checks = False , ** kwargs ):
637697 """
638698 :param str input_stream_name:
639699 :param None|list[str] files:
@@ -649,6 +709,7 @@ def __init__(self, input_stream_name, files=None, **kwargs):
649709 self .file_indices = []
650710 self .seq_order = []
651711 self .all_parsers = collections .defaultdict (list )
712+ self .use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
652713
653714 if files :
654715 for fn in files :
@@ -684,7 +745,9 @@ def add_file(self, path):
684745 )
685746
686747 parsers = {
687- name : NextGenHDFDataset .parsers [stream .attrs ["parser" ]](norm_seqs , stream )
748+ name : NextGenHDFDataset .parsers [stream .attrs ["parser" ]](
749+ norm_seqs , stream , use_lazy_data_integrity_checks = self .use_lazy_data_integrity_checks
750+ )
688751 for name , stream in cur_file ["streams" ].items ()
689752 }
690753 for k , v in parsers .items ():
@@ -807,7 +870,15 @@ class SiameseHDFDataset(CachedDataset2):
807870 "segment_alignment" : SegmentAlignmentStreamParser ,
808871 }
809872
810- def __init__ (self , input_stream_name , seq_label_stream = "words" , class_distribution = None , files = None , ** kwargs ):
873+ def __init__ (
874+ self ,
875+ input_stream_name ,
876+ seq_label_stream = "words" ,
877+ class_distribution = None ,
878+ files = None ,
879+ use_lazy_data_integrity_checks = False ,
880+ ** kwargs ,
881+ ):
811882 """
812883 :param str input_stream_name: name of a feature stream
813884 :param str seq_label_stream: name of a stream with labels
@@ -833,6 +904,8 @@ def __init__(self, input_stream_name, seq_label_stream="words", class_distributi
833904 self .target_to_seqs = {} # (int) class_index -> (string) sequence_names
834905 self .curr_epoch_triplets = []
835906 self .targets_stream = seq_label_stream
907+ self .use_lazy_data_integrity_checks = use_lazy_data_integrity_checks
908+
836909 if files :
837910 for fn in files :
838911 self .add_file (fn )
@@ -872,7 +945,9 @@ def add_file(self, path):
872945 )
873946
874947 parsers = {
875- name : SiameseHDFDataset .parsers [stream .attrs ["parser" ]](norm_seqs , stream )
948+ name : SiameseHDFDataset .parsers [stream .attrs ["parser" ]](
949+ norm_seqs , stream , use_lazy_data_integrity_checks = self .use_lazy_data_integrity_checks
950+ )
876951 for name , stream in cur_file ["streams" ].items ()
877952 } # name - stream name (words, features, orth_features)
878953 for k , v in parsers .items ():
0 commit comments