@@ -361,46 +361,6 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:
361361 return first_element , return_iterable
362362
363363
class RunningMeanAndVar:
    """Stores a running mean and variance using Welford's algorithm.

    Statistics are accumulated batch by batch along dimension 0 using the
    parallel (batched) form of the update, so the full data set never has to
    be held in memory at once.
    """

    def __init__(
        self,
        shape: Tuple[int, ...] = (),
        device: Optional[str] = None,
    ) -> None:
        """Initialize blank mean, variance, count.

        Args:
            shape: Shape of a single sample (i.e. of the tracked statistics).
            device: Device on which the running statistics are stored.
        """
        self.running_mean = th.zeros(shape, device=device)
        # Running sum of squared deviations from the mean.
        self.M2 = th.zeros(shape, device=device)
        self.count = 0

    def update(self, batch: th.Tensor) -> None:
        """Update the mean and variance with a batch of samples.

        Args:
            batch: Tensor whose first dimension indexes the samples. An empty
                batch leaves the statistics untouched.
        """
        batch_count = batch.shape[0]
        if batch_count == 0:
            # The mean of an empty batch is NaN and the total count could be
            # zero; either would permanently poison the running statistics.
            return

        with th.no_grad():
            batch_mean = th.mean(batch, dim=0)
            # Biased variance times the batch size is the batch's own M2.
            batch_var = th.var(batch, dim=0, unbiased=False)

            delta = batch_mean - self.running_mean
            tot_count = self.count + batch_count
            self.running_mean += delta * batch_count / tot_count

            self.M2 += batch_var * batch_count
            # Correction term for the shift between old and batch means.
            self.M2 += th.square(delta) * self.count * batch_count / tot_count

            self.count += batch_count

    @property
    def var(self) -> th.Tensor:
        """Returns the unbiased estimate of the variances.

        With fewer than two samples the estimate is undefined; the division
        is performed as-is (NaN for a single sample).
        """
        return self.M2 / (self.count - 1)

    @property
    def std(self) -> th.Tensor:
        """Returns the unbiased estimate of the standard deviations."""
        # Use the torch op so the result stays on the tensor's device;
        # a NumPy ufunc cannot operate on non-CPU tensors.
        return th.sqrt(self.var)
402-
403-
404364def compute_state_entropy (
405365 obs : th .Tensor ,
406366 all_obs : th .Tensor ,
0 commit comments