2424
2525import org .apache .sysds .common .Types .ValueType ;
2626import org .apache .sysds .runtime .data .TensorBlock ;
27- import java .util .Arrays ;
27+
28+ import java .util .Arrays ;
2829
2930public class TransposeLinDataTest {
3031
3132 @ Test
32- public void testRightElem (){
33+ public void testRightElem () {
3334 int [] shape = {2 , 3 , 4 };
3435 TensorBlock tensor = TensorUtils .createArangeTensor (shape );
3536
36- Assert .assertArrayEquals (new int []{2 , 3 , 4 }, tensor .getDims ());
37- Assert .assertEquals (0.0 , tensor .get (new int []{0 , 0 , 0 }));
38- Assert .assertEquals (23.0 , tensor .get (new int []{1 , 2 , 3 }));
39- Assert .assertEquals (6.0 , tensor .get (new int []{0 , 1 , 2 }));
40- Assert .assertEquals (12.0 , tensor .get (new int []{1 , 0 , 0 }));
37+ Assert .assertArrayEquals (new int [] {2 , 3 , 4 }, tensor .getDims ());
38+ Assert .assertEquals (0.0 , tensor .get (new int [] {0 , 0 , 0 }));
39+ Assert .assertEquals (23.0 , tensor .get (new int [] {1 , 2 , 3 }));
40+ Assert .assertEquals (6.0 , tensor .get (new int [] {0 , 1 , 2 }));
41+ Assert .assertEquals (12.0 , tensor .get (new int [] {1 , 0 , 0 }));
4142 printTensor (tensor );
4243
43-
4444 int [] permutation = {1 , 0 , 2 };
45- TensorBlock outTensor = PermuteIt .permute (tensor , permutation );
46- printTensor (outTensor );
47-
48- Assert .assertArrayEquals (new int []{3 , 2 , 4 }, outTensor .getDims ());
49- Assert .assertEquals (0.0 , outTensor .get (new int []{0 ,0 , 0 }));
50- Assert .assertEquals (23.0 , outTensor .get (new int []{2 , 1 , 3 }));
51- Assert .assertEquals (12.0 , outTensor .get (new int []{0 , 1 , 0 }));
52- Assert .assertEquals (17.0 , outTensor .get (new int []{1 , 1 , 1 }));
53-
54- int [] second_permutation = {2 , 1 , 0 };
55- TensorBlock perm2Block = PermuteIt .permute (tensor , second_permutation );
56- printTensor (perm2Block );
57-
58- Assert .assertArrayEquals (new int []{4 , 3 , 2 }, perm2Block .getDims ());
59- Assert .assertEquals (0.0 , perm2Block .get (new int []{0 , 0 , 0 }));
60- Assert .assertEquals (12.0 , perm2Block .get (new int []{0 , 0 , 1 }));
61- Assert .assertEquals (11.0 , perm2Block .get (new int []{3 , 2 , 0 }));
62- Assert .assertEquals (23.0 , perm2Block .get (new int []{3 , 2 , 1 }));
45+ TensorBlock outTensor = PermuteIt .permute (tensor , permutation );
46+ printTensor (outTensor );
47+
48+ Assert .assertArrayEquals (new int [] {3 , 2 , 4 }, outTensor .getDims ());
49+ Assert .assertEquals (0.0 , outTensor .get (new int [] {0 , 0 , 0 }));
50+ Assert .assertEquals (23.0 , outTensor .get (new int [] {2 , 1 , 3 }));
51+ Assert .assertEquals (12.0 , outTensor .get (new int [] {0 , 1 , 0 }));
52+ Assert .assertEquals (17.0 , outTensor .get (new int [] {1 , 1 , 1 }));
53+
54+ int [] second_permutation = {2 , 1 , 0 };
55+ TensorBlock perm2Block = PermuteIt .permute (tensor , second_permutation );
56+ printTensor (perm2Block );
57+
58+ Assert .assertArrayEquals (new int [] {4 , 3 , 2 }, perm2Block .getDims ());
59+ Assert .assertEquals (0.0 , perm2Block .get (new int [] {0 , 0 , 0 }));
60+ Assert .assertEquals (12.0 , perm2Block .get (new int [] {0 , 0 , 1 }));
61+ Assert .assertEquals (11.0 , perm2Block .get (new int [] {3 , 2 , 0 }));
62+ Assert .assertEquals (23.0 , perm2Block .get (new int [] {3 , 2 , 1 }));
6363 }
6464
6565 public class TensorUtils {
6666
6767 public static TensorBlock createArangeTensor (int [] shape ) {
6868 TensorBlock tb = new TensorBlock (ValueType .FP64 , shape );
6969 tb .allocateBlock ();
70- double [] counter = { 0.0 };
70+ double [] counter = {0.0 };
7171 int [] currentIndices = new int [shape .length ];
72-
72+
7373 fillRecursively (tb , shape , 0 , currentIndices , counter );
74-
74+
7575 return tb ;
7676 }
7777
78- private static void fillRecursively (TensorBlock tb , int [] shape , int dim , int [] currentIndices , double [] counter ) {
79- if (dim == shape .length ) {
78+ private static void fillRecursively (TensorBlock tb , int [] shape , int dim , int [] currentIndices ,
79+ double [] counter ) {
80+ if (dim == shape .length ) {
8081 tb .set (currentIndices , counter [0 ]);
81- counter [0 ]++;
82+ counter [0 ]++;
8283 return ;
8384 }
8485
85- for (int i = 0 ; i < shape [dim ]; i ++) {
86+ for (int i = 0 ; i < shape [dim ]; i ++) {
8687 currentIndices [dim ] = i ;
8788
8889 fillRecursively (tb , shape , dim + 1 , currentIndices , counter );
@@ -91,42 +92,41 @@ private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[]
9192 }
9293
9394 public class PermuteIt {
94- public static TensorBlock permute (TensorBlock tensor , int [] permute_dims ) {
95- int anz_dims = tensor .getNumDims ();
95+ public static TensorBlock permute (TensorBlock tensor , int [] permute_dims ) {
96+ int anz_dims = tensor .getNumDims ();
9697 int [] dims = tensor .getDims ();
9798 ValueType tensorType = tensor .getValueType ();
9899
99- int [] out_shape = new int [anz_dims ];
100+ int [] out_shape = new int [anz_dims ];
100101
101- for (int idx = 0 ; idx < anz_dims ; idx ++){
102+ for (int idx = 0 ; idx < anz_dims ; idx ++) {
102103 out_shape [idx ] = dims [permute_dims [idx ]];
103104 }
104105
105- TensorBlock outTensor = new TensorBlock (tensorType , out_shape );
106+ TensorBlock outTensor = new TensorBlock (tensorType , out_shape );
106107 outTensor .allocateBlock ();
107108
108- int [] inIndex = new int [anz_dims ];
109- int [] outIndex = new int [anz_dims ];
109+ int [] inIndex = new int [anz_dims ];
110+ int [] outIndex = new int [anz_dims ];
110111
111- recursion (tensor , outTensor , permute_dims , dims , 0 , inIndex , outIndex );
112+ recursion (tensor , outTensor , permute_dims , dims , 0 , inIndex , outIndex );
112113 return outTensor ;
113114 }
114115
115- public static void recursion (TensorBlock inTensor , TensorBlock outTensor ,
116- int [] permutation , int [] inShape , int dim , int [] inIndex , int []outIndex )
117- {
118- if (dim == inShape .length ) {
119- for (int idx = 0 ; idx < permutation .length ; idx ++){
120- outIndex [idx ] = inIndex [permutation [idx ]];
116+ public static void recursion (TensorBlock inTensor , TensorBlock outTensor , int [] permutation , int [] inShape ,
117+ int dim , int [] inIndex , int [] outIndex ) {
118+ if (dim == inShape .length ) {
119+ for (int idx = 0 ; idx < permutation .length ; idx ++) {
120+ outIndex [idx ] = inIndex [permutation [idx ]];
121121 }
122122 double val = (double ) inTensor .get (inIndex );
123123 outTensor .set (outIndex , val );
124- return ;
124+ return ;
125125 }
126126
127- for (int idx = 0 ; idx < inShape [dim ]; idx ++){
128- inIndex [dim ] = idx ;
129- recursion (inTensor , outTensor , permutation , inShape , dim + 1 , inIndex , outIndex );
127+ for (int idx = 0 ; idx < inShape [dim ]; idx ++) {
128+ inIndex [dim ] = idx ;
129+ recursion (inTensor , outTensor , permutation , inShape , dim + 1 , inIndex , outIndex );
130130 }
131131 }
132132 }
@@ -135,42 +135,47 @@ public static void printTensor(TensorBlock tb) {
135135 StringBuilder sb = new StringBuilder ();
136136 int [] shape = tb .getDims ();
137137 int [] currentIndices = new int [shape .length ];
138-
138+
139139 sb .append ("Tensor(" ).append (Arrays .toString (shape )).append ("):\n " );
140140 printRecursive (tb , shape , 0 , currentIndices , sb , 0 );
141-
141+
142142 System .out .println (sb .toString ());
143143 }
144144
145- private static void printRecursive (TensorBlock tb , int [] shape , int dim , int [] indices , StringBuilder sb , int indent ) {
146- for (int k = 0 ; k < indent ; k ++) sb .append (" " );
145+ private static void printRecursive (TensorBlock tb , int [] shape , int dim , int [] indices , StringBuilder sb ,
146+ int indent ) {
147+ for (int k = 0 ; k < indent ; k ++)
148+ sb .append (" " );
147149
148150 sb .append ("[" );
149151
150- if (dim == shape .length - 1 ) {
151- for (int i = 0 ; i < shape [dim ]; i ++) {
152+ if (dim == shape .length - 1 ) {
153+ for (int i = 0 ; i < shape [dim ]; i ++) {
152154 indices [dim ] = i ;
153- double val = (double ) tb .get (indices );
154- sb .append (String .format ("%.1f" , val ));
155- if (i < shape [dim ] - 1 ) sb .append (", " );
155+ double val = (double ) tb .get (indices );
156+ sb .append (String .format ("%.1f" , val ));
157+ if (i < shape [dim ] - 1 )
158+ sb .append (", " );
156159 }
157160 sb .append ("]" );
158- }
161+ }
159162
160163 else {
161164 sb .append ("\n " );
162- for (int i = 0 ; i < shape [dim ]; i ++) {
165+ for (int i = 0 ; i < shape [dim ]; i ++) {
163166 indices [dim ] = i ;
164167 printRecursive (tb , shape , dim + 1 , indices , sb , indent + 2 );
165-
166- if (i < shape [dim ] - 1 ) {
168+
169+ if (i < shape [dim ] - 1 ) {
167170 sb .append ("," );
168- sb .append ("\n " );
169- if (shape .length - dim > 2 ) sb .append ("\n " );
171+ sb .append ("\n " );
172+ if (shape .length - dim > 2 )
173+ sb .append ("\n " );
170174 }
171175 }
172- sb .append ("\n " );
173- for (int k = 0 ; k < indent ; k ++) sb .append (" " );
176+ sb .append ("\n " );
177+ for (int k = 0 ; k < indent ; k ++)
178+ sb .append (" " );
174179 sb .append ("]" );
175180 }
176181 }
0 commit comments