|
17 | 17 |
|
18 | 18 | import dpnp |
19 | 19 | from dpnp.dpnp_array import dpnp_array |
| 20 | +from dpnp.tests.helper import ( |
| 21 | + assert_array_equal, |
| 22 | + assert_dtype_allclose, |
| 23 | + generate_random_numpy_array, |
| 24 | + get_all_dtypes, |
| 25 | +) |
20 | 26 |
|
21 | 27 | from .helper import ( |
22 | 28 | get_abs_array, |
@@ -954,126 +960,162 @@ def test_indices(dimension, dtype, sparse): |
954 | 960 | assert_array_equal(Xnp, X) |
955 | 961 |
|
956 | 962 |
|
957 | | -@pytest.mark.parametrize( |
958 | | - "mask", |
959 | | - [ |
960 | | - [[True, False], [False, True]], |
961 | | - [[False, True], [True, False]], |
962 | | - [[False, False], [True, True]], |
963 | | - ], |
964 | | - ids=[ |
965 | | - "[[True, False], [False, True]]", |
966 | | - "[[False, True], [True, False]]", |
967 | | - "[[False, False], [True, True]]", |
968 | | - ], |
969 | | -) |
970 | | -@pytest.mark.parametrize( |
971 | | - "arr", |
972 | | - [[[0, 0], [0, 0]], [[1, 2], [1, 2]], [[1, 2], [3, 4]]], |
973 | | - ids=["[[0, 0], [0, 0]]", "[[1, 2], [1, 2]]", "[[1, 2], [3, 4]]"], |
974 | | -) |
975 | | -def test_putmask1(arr, mask): |
976 | | - a = numpy.array(arr) |
977 | | - ia = dpnp.array(a) |
978 | | - m = numpy.array(mask) |
979 | | - im = dpnp.array(m) |
980 | | - v = numpy.array([100, 200]) |
981 | | - iv = dpnp.array(v) |
982 | | - numpy.putmask(a, m, v) |
983 | | - dpnp.putmask(ia, im, iv) |
984 | | - assert_array_equal(a, ia) |
| 963 | +class TestPutmask: |
| 964 | + @pytest.mark.parametrize( |
| 965 | + "shape", |
| 966 | + [ |
| 967 | + (1,), |
| 968 | + (5,), |
| 969 | + (4, 3), |
| 970 | + (3, 3), |
| 971 | + (5, 3), |
| 972 | + (3, 4, 5), |
| 973 | + ], |
| 974 | + ) |
| 975 | + @pytest.mark.parametrize( |
| 976 | + "dt", get_all_dtypes(no_bool=True, no_float16=False) |
| 977 | + ) |
| 978 | + @pytest.mark.parametrize("order", ["C", "F"]) |
| 979 | + def test_putmask_scalar_values(self, shape, dt, order): |
| 980 | + a_np = generate_random_numpy_array(shape, order=order, dtype=dt) |
| 981 | + mask_np = a_np > 0 |
| 982 | + val = numpy.array(7, dtype=dt).item() |
985 | 983 |
|
| 984 | + a_dp = dpnp.array(a_np) |
| 985 | + mask_dp = dpnp.array(mask_np) |
986 | 986 |
|
987 | | -@pytest.mark.parametrize( |
988 | | - "vals", |
989 | | - [ |
990 | | - [100, 200], |
991 | | - [100, 200, 300, 400, 500, 600], |
992 | | - [100, 200, 300, 400, 500, 600, 800, 900], |
993 | | - ], |
994 | | - ids=[ |
995 | | - "[100, 200]", |
996 | | - "[100, 200, 300, 400, 500, 600]", |
997 | | - "[100, 200, 300, 400, 500, 600, 800, 900]", |
998 | | - ], |
999 | | -) |
1000 | | -@pytest.mark.parametrize( |
1001 | | - "mask", |
1002 | | - [ |
| 987 | + dpnp.putmask(a_dp, mask_dp, val) |
| 988 | + numpy.putmask(a_np, mask_np, val) |
| 989 | + |
| 990 | + assert_dtype_allclose(a_dp, a_np) |
| 991 | + |
| 992 | + @pytest.mark.parametrize( |
| 993 | + "shape", |
1003 | 994 | [ |
1004 | | - [[True, False], [False, True]], |
1005 | | - [[False, True], [True, False]], |
1006 | | - [[False, False], [True, True]], |
1007 | | - ] |
1008 | | - ], |
1009 | | - ids=[ |
1010 | | - "[[[True, False], [False, True]], [[False, True], [True, False]], [[False, False], [True, True]]]" |
1011 | | - ], |
1012 | | -) |
1013 | | -@pytest.mark.parametrize( |
1014 | | - "arr", |
1015 | | - [[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]], |
1016 | | - ids=["[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]"], |
1017 | | -) |
1018 | | -def test_putmask2(arr, mask, vals): |
1019 | | - a = numpy.array(arr) |
1020 | | - ia = dpnp.array(a) |
1021 | | - m = numpy.array(mask) |
1022 | | - im = dpnp.array(m) |
1023 | | - v = numpy.array(vals) |
1024 | | - iv = dpnp.array(v) |
1025 | | - numpy.putmask(a, m, v) |
1026 | | - dpnp.putmask(ia, im, iv) |
1027 | | - assert_array_equal(a, ia) |
| 995 | + (1,), |
| 996 | + (5,), |
| 997 | + (4, 3), |
| 998 | + (3, 3), |
| 999 | + (5, 3), |
| 1000 | + (3, 4, 5), |
| 1001 | + ], |
| 1002 | + ) |
| 1003 | + @pytest.mark.parametrize( |
| 1004 | + "dt", get_all_dtypes(no_bool=True, no_float16=False) |
| 1005 | + ) |
| 1006 | + @pytest.mark.parametrize("order", ["C", "F"]) |
| 1007 | + def test_putmask_same_shape(self, shape, dt, order): |
| 1008 | + a_np = generate_random_numpy_array(shape, dtype=dt, order=order) |
| 1009 | + mask_np = a_np > 0 |
| 1010 | + val_np = generate_random_numpy_array(shape, dtype=dt, order=order) |
1028 | 1011 |
|
| 1012 | + a_dp = dpnp.array(a_np, order=order) |
| 1013 | + mask_dp = dpnp.array(mask_np, order=order) |
| 1014 | + val_dp = dpnp.array(val_np, order=order) |
1029 | 1015 |
|
1030 | | -@pytest.mark.parametrize( |
1031 | | - "vals", |
1032 | | - [ |
1033 | | - [100, 200], |
1034 | | - [100, 200, 300, 400, 500, 600], |
1035 | | - [100, 200, 300, 400, 500, 600, 800, 900], |
1036 | | - ], |
1037 | | - ids=[ |
1038 | | - "[100, 200]", |
1039 | | - "[100, 200, 300, 400, 500, 600]", |
1040 | | - "[100, 200, 300, 400, 500, 600, 800, 900]", |
1041 | | - ], |
1042 | | -) |
1043 | | -@pytest.mark.parametrize( |
1044 | | - "mask", |
1045 | | - [ |
| 1016 | + dpnp.putmask(a_dp, mask_dp, val_dp) |
| 1017 | + numpy.putmask(a_np, mask_np, val_np) |
| 1018 | + |
| 1019 | + assert_dtype_allclose(a_dp, a_np) |
| 1020 | + |
| 1021 | + @pytest.mark.parametrize( |
| 1022 | + "a_shape,val_shape", |
1046 | 1023 | [ |
1047 | | - [[[False, False], [True, True]], [[True, True], [True, True]]], |
1048 | | - [[[False, False], [True, True]], [[False, False], [False, False]]], |
1049 | | - ] |
1050 | | - ], |
1051 | | - ids=[ |
1052 | | - "[[[[False, False], [True, True]], [[True, True], [True, True]]], [[[False, False], [True, True]], [[False, False], [False, False]]]]" |
1053 | | - ], |
1054 | | -) |
1055 | | -@pytest.mark.parametrize( |
1056 | | - "arr", |
1057 | | - [ |
| 1024 | + ((6,), (3,)), |
| 1025 | + ((6,), (7,)), |
| 1026 | + ((2, 3), (5,)), |
| 1027 | + ((6, 3), (5, 4)), |
| 1028 | + ((4, 3, 5), (8,)), |
| 1029 | + ((2, 4, 3), (5, 5, 2)), |
| 1030 | + ], |
| 1031 | + ) |
| 1032 | + @pytest.mark.parametrize( |
| 1033 | + "dt", get_all_dtypes(no_bool=True, no_float16=False) |
| 1034 | + ) |
| 1035 | + @pytest.mark.parametrize("order", ["C"]) # need to add "F" |
| 1036 | + def test_putmask_kernel(self, a_shape, val_shape, dt, order): |
| 1037 | + a_np = generate_random_numpy_array(a_shape, dtype=dt, order=order) |
| 1038 | + mask_np = a_np > 0 |
| 1039 | + val_np = generate_random_numpy_array(val_shape, dtype=dt, order=order) |
| 1040 | + |
| 1041 | + a_dp = dpnp.array(a_np, order=order) |
| 1042 | + mask_dp = dpnp.array(mask_np, order=order) |
| 1043 | + val_dp = dpnp.array(val_np, order=order) |
| 1044 | + |
| 1045 | + dpnp.putmask(a_dp, mask_dp, val_dp) |
| 1046 | + numpy.putmask(a_np, mask_np, val_np) |
| 1047 | + |
| 1048 | + assert_dtype_allclose(a_dp, a_np) |
| 1049 | + |
| 1050 | + # test_putmask_strided |
| 1051 | + |
| 1052 | + def test_putmask_mask_cast_to_bool(self): |
| 1053 | + a_np = generate_random_numpy_array((5, 5), dtype="f4") |
| 1054 | + mask_np = generate_random_numpy_array((5, 5), dtype="int64") |
| 1055 | + val_np = generate_random_numpy_array((7,), dtype="f4") |
| 1056 | + |
| 1057 | + a_dp = dpnp.array(a_np) |
| 1058 | + mask_dp = dpnp.array(mask_np) |
| 1059 | + val_dp = dpnp.array(val_np) |
| 1060 | + |
| 1061 | + dpnp.putmask(a_dp, mask_dp, val_dp) |
| 1062 | + numpy.putmask(a_np, mask_np, val_np) |
| 1063 | + |
| 1064 | + assert_dtype_allclose(a_dp, a_np) |
| 1065 | + |
| 1066 | + @pytest.mark.parametrize( |
| 1067 | + "dt", get_all_dtypes(no_bool=True, no_float16=False) |
| 1068 | + ) |
| 1069 | + @pytest.mark.parametrize("order", ["C", "F"]) |
| 1070 | + @pytest.mark.parametrize( |
| 1071 | + "shape", |
1058 | 1072 | [ |
1059 | | - [[[1, 2], [3, 4]], [[1, 2], [2, 1]]], |
1060 | | - [[[1, 3], [3, 1]], [[0, 1], [1, 3]]], |
1061 | | - ] |
1062 | | - ], |
1063 | | - ids=[ |
1064 | | - "[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]" |
1065 | | - ], |
1066 | | -) |
1067 | | -def test_putmask3(arr, mask, vals): |
1068 | | - a = numpy.array(arr) |
1069 | | - ia = dpnp.array(a) |
1070 | | - m = numpy.array(mask) |
1071 | | - im = dpnp.array(m) |
1072 | | - v = numpy.array(vals) |
1073 | | - iv = dpnp.array(v) |
1074 | | - numpy.putmask(a, m, v) |
1075 | | - dpnp.putmask(ia, im, iv) |
1076 | | - assert_array_equal(a, ia) |
| 1073 | + (0,), |
| 1074 | + (0, 3), |
| 1075 | + (2, 0), |
| 1076 | + (0, 2, 3), |
| 1077 | + ], |
| 1078 | + ) |
| 1079 | + @pytest.mark.parametrize( |
| 1080 | + "values_case", ["scalar", "same_shape", "diff_shape"] |
| 1081 | + ) |
| 1082 | + def test_putmask_empty(self, dt, order, shape, values_case): |
| 1083 | + a_np = numpy.empty(shape, dtype=dt, order=order) |
| 1084 | + mask_np = numpy.empty(shape, dtype=numpy.bool_, order=order) |
| 1085 | + |
| 1086 | + if values_case == "scalar": |
| 1087 | + val_np = numpy.asarray(1, dtype=dt).item() |
| 1088 | + val_dp = val_np |
| 1089 | + elif values_case == "same_shape": |
| 1090 | + val_np = numpy.empty(shape, dtype=dt, order=order) |
| 1091 | + val_dp = dpnp.array(val_np, order=order) |
| 1092 | + else: |
| 1093 | + # different shape |
| 1094 | + val_np = numpy.array([1, 2], dtype=dt) |
| 1095 | + val_dp = dpnp.array(val_np) |
| 1096 | + |
| 1097 | + a_dp = dpnp.array(a_np, order=order) |
| 1098 | + mask_dp = dpnp.array(mask_np, order=order) |
| 1099 | + |
| 1100 | + dpnp.putmask(a_dp, mask_dp, val_dp) |
| 1101 | + numpy.putmask(a_np, mask_np, val_np) |
| 1102 | + |
| 1103 | + assert_dtype_allclose(a_dp, a_np) |
| 1104 | + |
| 1105 | + def test_putmask_errors(self): |
| 1106 | + # shape mask mismatch |
| 1107 | + a = dpnp.arange(6).reshape(2, 3) |
| 1108 | + mask_bad = dpnp.ones((3, 2), dtype=dpnp.bool) |
| 1109 | + assert_raises(ValueError, dpnp.putmask, a, mask_bad, 1) |
| 1110 | + |
| 1111 | + # safe-cast error |
| 1112 | + a = dpnp.arange(10, dtype=dpnp.int32) |
| 1113 | + mask = a > 3 |
| 1114 | + val_f = dpnp.array([1.5, 2.5], dtype="f4") |
| 1115 | + assert_raises(TypeError, dpnp.putmask, a, mask, val_f) |
| 1116 | + |
| 1117 | + # values as list |
| 1118 | + assert_raises(TypeError, dpnp.putmask, a, mask, [1, 2, 3]) |
1077 | 1119 |
|
1078 | 1120 |
|
1079 | 1121 | @pytest.mark.parametrize("m", [None, 0, 1, 2, 3, 4]) |
|
0 commit comments