@@ -68,6 +68,12 @@ def __init__(self, parent, data, node_name):
6868
6969 # data - signal on level 0, coeffs on higher levels
7070 self .data = data
71+ # Need to retain original data size/shape so we can trim any excess
72+ # boundary coefficients from the inverse transform.
73+ if self .data is None :
74+ self ._data_shape = None
75+ else :
76+ self ._data_shape = np .asarray (data ).shape
7177
7278 self ._init_subnodes ()
7379
@@ -436,6 +442,9 @@ def _reconstruct(self, update):
436442 " from subnodes." )
437443 else :
438444 rec = idwt (data_a , data_d , self .wavelet , self .mode )
445+ if self ._data_shape is not None and (
446+ rec .shape != self ._data_shape ):
447+ rec = rec [tuple ([slice (sz ) for sz in self ._data_shape ])]
439448 if update :
440449 self .data = rec
441450 return rec
@@ -504,6 +513,9 @@ def _reconstruct(self, update):
504513 else :
505514 coeffs = data_ll , (data_hl , data_lh , data_hh )
506515 rec = idwt2 (coeffs , self .wavelet , self .mode )
516+ if self ._data_shape is not None and (
517+ rec .shape != self ._data_shape ):
518+ rec = rec [tuple ([slice (sz ) for sz in self ._data_shape ])]
507519 if update :
508520 self .data = rec
509521 return rec
@@ -568,8 +580,6 @@ def reconstruct(self, update=True):
568580 """
569581 if self .has_any_subnode :
570582 data = super (WaveletPacket , self ).reconstruct (update )
571- if self .data_size is not None and len (data ) > self .data_size :
572- data = data [:self .data_size ]
573583 if update :
574584 self .data = data
575585 return data
@@ -669,8 +679,6 @@ def reconstruct(self, update=True):
669679 """
670680 if self .has_any_subnode :
671681 data = super (WaveletPacket2D , self ).reconstruct (update )
672- if self .data_size is not None and (data .shape != self .data_size ):
673- data = data [:self .data_size [0 ], :self .data_size [1 ]]
674682 if update :
675683 self .data = data
676684 return data
0 commit comments