@@ -362,10 +362,10 @@ def get_first_iter_element(iterable: Iterable[T]) -> Tuple[T, Iterable[T]]:
362362
363363
364364def compute_state_entropy (
365- obs : np . ndarray ,
366- all_obs : np . ndarray ,
365+ obs : th . Tensor ,
366+ all_obs : th . Tensor ,
367367 k : int ,
368- ) -> np . ndarray :
368+ ) -> th . Tensor :
369369 """Compute the state entropy given by KNN distance.
370370
371371 Args:
@@ -379,19 +379,15 @@ def compute_state_entropy(
379379 assert obs .shape [1 :] == all_obs .shape [1 :]
380380 with th .no_grad ():
381381 non_batch_dimensions = tuple (range (2 , len (obs .shape ) + 1 ))
382- distances_tensor = np .linalg .norm (
382+ distances_tensor = th .linalg .vector_norm (
383383 obs [:, None ] - all_obs [None , :],
384- axis = non_batch_dimensions ,
384+ dim = non_batch_dimensions ,
385385 ord = 2 ,
386386 )
387387
388388 # Note that we take the k+1'th value because the closest neighbor to
389389 # a point is itself, which we want to skip.
390- knn_dists = kth_value (distances_tensor , k + 1 )
390+ assert distances_tensor .shape [- 1 ] > k
391+ knn_dists = th .kthvalue (distances_tensor , k = k + 1 , dim = 1 ).values
391392 state_entropy = knn_dists
392- return np .expand_dims (state_entropy , axis = 1 )
393-
394-
395- def kth_value (x : np .ndarray , k : int ):
396- assert k > 0
397- return np .partition (x , k - 1 , axis = - 1 )[..., k - 1 ]
393+ return state_entropy .unsqueeze (1 )
0 commit comments