File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -394,6 +394,13 @@ def one_hot(
394394) -> Array :
395395 if xp is None :
396396 xp = array_namespace (x )
397+ x_size = x .size
398+ if x_size is None :
399+ msg = "x must have a concrete size."
400+ raise TypeError (msg )
401+ if not xp .isdtype (x .dtype , "integral" ):
402+ msg = "x must have an integral dtype."
403+ raise TypeError (msg )
397404 if is_jax_namespace (xp ):
398405 assert is_jax_array (x )
399406 from jax .nn import one_hot
@@ -412,10 +419,6 @@ def one_hot(
412419 dtype = xp .empty (()).dtype # Default float dtype
413420 out = xp .zeros ((x .size , num_classes ), dtype = dtype )
414421 x_flattened = xp .reshape (x , (- 1 ,))
415- x_size = x .size
416- if x_size is None :
417- msg = "x must have a concrete size."
418- raise TypeError (msg )
419422 if is_numpy_namespace (xp ):
420423 at (out )[xp .arange (x_size ), x_flattened ].set (1 )
421424 else :
Original file line number Diff line number Diff line change @@ -455,7 +455,7 @@ def test_xp(self, xp: ModuleType):
455455)
456456class TestOneHot :
457457 @pytest .mark .parametrize ("n_dim" , range (4 ))
458- @pytest .mark .parametrize ("num_classes" , range ( 1 , 5 , 2 ) )
458+ @pytest .mark .parametrize ("num_classes" , [ 1 , 3 , 10 ] )
459459 def test_dims_and_classes (self , xp : ModuleType , n_dim : int , num_classes : int ):
460460 shape = tuple (range (2 , 2 + n_dim ))
461461 rng = np .random .default_rng (2347823 )
@@ -508,7 +508,7 @@ def test_axis(self, xp: ModuleType):
508508 xp_assert_equal (actual , expected )
509509
510510 def test_non_integer (self , xp : ModuleType ):
511- with pytest .raises (( TypeError , RuntimeError , IndexError , DeprecationWarning ) ):
511+ with pytest .raises (TypeError ):
512512 one_hot (xp .asarray ([1.0 ]), 3 )
513513
514514
You can’t perform that action at this time.
0 commit comments