File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -94,14 +94,15 @@ def load_or_create_state(
9494 M2 = torch .load (
9595 os .path .join (store_dir , "M2.pt" ), weights_only = True , map_location = device
9696 )
97- return RunningStatWelford (
97+ stat = RunningStatWelford (
9898 shape = mean .shape ,
99- dtype = dtype ,
99+ dtype = mean . dtype ,
100100 device = device ,
101- count = count ,
102- mean = mean ,
103- M2 = M2 ,
104101 )
102+ stat .count = count
103+ stat .mean = mean
104+ stat .M2 = M2
105+ return stat
105106 else :
106107 return RunningStatWelford (shape = shape , dtype = dtype , device = device )
107108
@@ -306,6 +307,12 @@ def std(self):
306307 def normalizer (self ):
307308 return ActivationNormalizer (self .mean , self .std )
308309
310+ @property
311+ def running_stats (self ):
312+ return RunningStatWelford .load_or_create_state (
313+ self ._cache_store_dir , shape = (self .config ["d_model" ],)
314+ )
315+
309316 def __len__ (self ):
310317 return self .config ["total_size" ]
311318
You can’t perform that action at this time.
0 commit comments