@@ -110,8 +110,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
110110 self .patch_size = patch_size
111111 self .patch_method = patch_method
112112
113- self .register_buffer ("wavelets" , _WAVELETS [patch_method ], persistent = False )
114- self .register_buffer ("_arange" , torch .arange (_WAVELETS [patch_method ].shape [0 ]), persistent = False )
113+ wavelets = _WAVELETS .get (patch_method ).clone ()
114+ arange = torch .arange (wavelets .shape [0 ])
115+
116+ self .register_buffer ("wavelets" , wavelets , persistent = False )
117+ self .register_buffer ("_arange" , arange , persistent = False )
115118
116119 def _dwt (self , hidden_states : torch .Tensor , mode : str = "reflect" , rescale = False ) -> torch .Tensor :
117120 dtype = hidden_states .dtype
@@ -185,12 +188,11 @@ def __init__(self, patch_size: int = 1, patch_method: str = "haar"):
185188 self .patch_size = patch_size
186189 self .patch_method = patch_method
187190
188- self .register_buffer ("wavelets" , _WAVELETS [patch_method ], persistent = False )
189- self .register_buffer (
190- "_arange" ,
191- torch .arange (_WAVELETS [patch_method ].shape [0 ]),
192- persistent = False ,
193- )
191+ wavelets = _WAVELETS .get (patch_method ).clone ()
192+ arange = torch .arange (wavelets .shape [0 ])
193+
194+ self .register_buffer ("wavelets" , wavelets , persistent = False )
195+ self .register_buffer ("_arange" , arange , persistent = False )
194196
195197 def _idwt (self , hidden_states : torch .Tensor , rescale : bool = False ) -> torch .Tensor :
196198 device = hidden_states .device
0 commit comments